YasiiKB commited on
Commit
97aa5af
·
verified ·
1 Parent(s): d7207a4

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +32 -0
  2. .gitignore +12 -0
  3. README.md +251 -0
  4. assets/r3pmnet_overview.png +3 -0
  5. assets/sioux_cranfield.png +3 -0
  6. assets/sioux_scans.png +3 -0
  7. assets/success_cases.png +3 -0
  8. assets/teaser.png +3 -0
  9. config/default.yaml +24 -0
  10. config/eval.yaml +28 -0
  11. dataloader/README.md +71 -0
  12. dataloader/__init__.py +1 -0
  13. dataloader/args.txt +14 -0
  14. dataloader/data_dict_generator.py +119 -0
  15. dataloader/dataset_generator.py +258 -0
  16. dataloader/user_data.py +127 -0
  17. environment.yml +16 -0
  18. pyproject.toml +25 -0
  19. r3pm_net/__init__.py +5 -0
  20. r3pm_net/config_loader.py +164 -0
  21. r3pm_net/feature_extractor.py +8 -0
  22. r3pm_net/model.py +382 -0
  23. r3pm_net/paths.py +11 -0
  24. scripts/eval_modelnet40.py +335 -0
  25. scripts/eval_sioux_cranfield.py +302 -0
  26. scripts/eval_sioux_scans.py +341 -0
  27. scripts/modelnet40.sh +45 -0
  28. scripts/sioux_cranfield.sh +46 -0
  29. scripts/sioux_scans.sh +45 -0
  30. src/train.py +366 -0
  31. thirdparty/__init__.py +1 -0
  32. thirdparty/learning3d/data_utils/__init__.py +4 -0
  33. thirdparty/learning3d/data_utils/dataloaders.py +454 -0
  34. thirdparty/learning3d/data_utils/user_data.py +119 -0
  35. thirdparty/learning3d/examples/test_curvenet.py +118 -0
  36. thirdparty/learning3d/examples/test_dcp.py +139 -0
  37. thirdparty/learning3d/examples/test_deepgmr.py +144 -0
  38. thirdparty/learning3d/examples/test_flownet.py +113 -0
  39. thirdparty/learning3d/examples/test_masknet.py +159 -0
  40. thirdparty/learning3d/examples/test_masknet2.py +162 -0
  41. thirdparty/learning3d/examples/test_pcn.py +118 -0
  42. thirdparty/learning3d/examples/test_pcrnet.py +120 -0
  43. thirdparty/learning3d/examples/test_pnlk.py +121 -0
  44. thirdparty/learning3d/examples/test_pointconv.py +126 -0
  45. thirdparty/learning3d/examples/test_pointnet.py +121 -0
  46. thirdparty/learning3d/examples/test_prnet.py +126 -0
  47. thirdparty/learning3d/examples/test_rpmnet.py +120 -0
  48. thirdparty/learning3d/examples/train_PointNetLK.py +240 -0
  49. thirdparty/learning3d/examples/train_dcp.py +249 -0
  50. thirdparty/learning3d/examples/train_deepgmr.py +244 -0
.gitattributes CHANGED
@@ -33,3 +33,35 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/r3pmnet_overview.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/sioux_cranfield.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/sioux_scans.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/success_cases.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
41
+ thirdparty/learning3d/pretrained/exp_classifier/models/best_model_snap.t7 filter=lfs diff=lfs merge=lfs -text
42
+ thirdparty/learning3d/pretrained/exp_classifier/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
43
+ thirdparty/learning3d/pretrained/exp_classifier/models/best_ptnet_model.t7 filter=lfs diff=lfs merge=lfs -text
44
+ thirdparty/learning3d/pretrained/exp_curvenet/models/model.t7 filter=lfs diff=lfs merge=lfs -text
45
+ thirdparty/learning3d/pretrained/exp_dcp/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
46
+ thirdparty/learning3d/pretrained/exp_flownet/models/model.best.t7 filter=lfs diff=lfs merge=lfs -text
47
+ thirdparty/learning3d/pretrained/exp_ipcrnet/models/best_model_v1.t7 filter=lfs diff=lfs merge=lfs -text
48
+ thirdparty/learning3d/pretrained/exp_ipcrnet/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
49
+ thirdparty/learning3d/pretrained/exp_ipcrnet/models/best_ptnet_model.t7 filter=lfs diff=lfs merge=lfs -text
50
+ thirdparty/learning3d/pretrained/exp_masknet/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
51
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.01.t7 filter=lfs diff=lfs merge=lfs -text
52
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.6.t7 filter=lfs diff=lfs merge=lfs -text
53
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.7.t7 filter=lfs diff=lfs merge=lfs -text
54
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.8.t7 filter=lfs diff=lfs merge=lfs -text
55
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_0.9.t7 filter=lfs diff=lfs merge=lfs -text
56
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_100.t7 filter=lfs diff=lfs merge=lfs -text
57
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_200.t7 filter=lfs diff=lfs merge=lfs -text
58
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_300.t7 filter=lfs diff=lfs merge=lfs -text
59
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_400.t7 filter=lfs diff=lfs merge=lfs -text
60
+ thirdparty/learning3d/pretrained/exp_masknet2/models/best_model_500.t7 filter=lfs diff=lfs merge=lfs -text
61
+ thirdparty/learning3d/pretrained/exp_pcn/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
62
+ thirdparty/learning3d/pretrained/exp_pnlk/models/best_model_snap.t7 filter=lfs diff=lfs merge=lfs -text
63
+ thirdparty/learning3d/pretrained/exp_pnlk/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
64
+ thirdparty/learning3d/pretrained/exp_pnlk/models/best_ptnet_model.t7 filter=lfs diff=lfs merge=lfs -text
65
+ thirdparty/learning3d/pretrained/exp_prnet/models/best_model.t7 filter=lfs diff=lfs merge=lfs -text
66
+ thirdparty/learning3d/pretrained/exp_prnet/models/model.99.t7 filter=lfs diff=lfs merge=lfs -text
67
+ thirdparty/learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *:*Zone.Identifier
3
+ .ipynb_checkpoints/
4
+ *.ipynb
5
+
6
+ checkpoints/
7
+ data/
8
+ results/
9
+ registration_plys/
10
+ logs/
11
+ notebooks/
12
+ kernels/
README.md ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- # R3PM-Net
2
+
3
+
4
+
5
+ This repository contains the official implementation of the paper:
6
+
7
+ <p align="center">
8
+ <strong><a href="https://arxiv.org/abs/2604.05060">R3PM-Net: Real-time, Robust, Real-world Point Matching Network</a></strong><br>
9
+ <strong>(AI4RWC@CVPRW 2026 - Oral Presentation)</strong>
10
+ </p> -->
11
+
12
+ <p align="center">
13
+
14
+ <h1 align="center">R3PM-Net: Real-time, Robust, Real-world Point Matching Network</h1>
15
+ <p align="center"> <strong>AI4RWC@CVPRW 2026 - Oral Presentation</strong></p>
16
+ <h3 align="center"><a href="https://arxiv.org/abs/2604.05060">Paper</a> | <a href="https://yasiikb.github.io/R3PM-Net/">Project Page</a> | <a href="https://huggingface.co/datasets/YasiiKB/R3PM-Net">Dataset</a></h3>
17
+ <div align="center"></div>
18
+ </p>
19
+ <p align="center"> <img src="assets/r3pmnet_overview.png" width="95%"> </p>
20
+ <p align="left"><i>Figure 1. Overview of the R3PM-Net Architecture. R3PM-Net employs a global-aware feature extraction module with shared weights to learn geometric similarities across a full receptive field.</i></p>
21
+
22
+ ## Introduction
23
+
24
+ R3PM-Net is a lightweight, global-aware, object-level point matching network designed to bridge the gap between approaches trained and evaluated on clean, dense, synthetic and real-world industrial point cloud data by prioritizing both generalizability and real-time efficiency.
25
+
26
+ <p align="center"> <img src="assets/teaser.png" width="40%"> </p>
27
+ <p align="left"><i>Figure 2. Examples of R3PM-Net performance on the Sioux-Cranfield dataset.</i></p>
28
+
29
+ ### Datasets
30
+
31
+ We propose two datasets; **Sioux-Cranfield** and **Sioux-Scans**, to address the gap between synthetic datasets and real-world industrial data.
32
+
33
+ <p align="center">
34
+ <table>
35
+ <tr>
36
+ <td align="center">
37
+ <img src="assets/sioux_cranfield.png" height="250">
38
+ <br>
39
+ <sub><b>Sioux-Cranfield</b></sub>
40
+ </td>
41
+ <td align="center">
42
+ <img src="assets/sioux_scans.png" height="250">
43
+ <br>
44
+ <sub><b>Sioux-Scans</b></sub>
45
+ </td>
46
+ </tr>
47
+ </table>
48
+ </p>
49
+ <p align="left"><i>Figure 3. CAD models of the Sioux-Cranfield dataset (Left). The first six belong to the Cranfield Assembly benchmark and the rest are contributions of this paper (Sioux dataset). Sioux-Scans point cloud data (Right). Target (blue) and Source (yellow) point clouds for seven distinct objects.</i></p>
50
+
51
+ ## Environment Setup
52
+
53
+ ```bash
54
+ # 1. Create environment
55
+ conda env create -f environment.yml
56
+ conda activate r3pm_net
57
+
58
+ # Optionally, install the dependencies and run manually:
59
+ pip install -e .
60
+ ```
61
+
62
+ To run the evaluations, please refer to each method's repo to set up the environment:
63
+ [Predator](https://github.com/prs-eth/OverlapPredator),
64
+ [GeoTransformer](https://github.com/qinzheng93/geotransformer),
65
+ [LoGDesc](https://github.com/karim416/LoGDesc), and
66
+ [RegTR](https://github.com/yewzijian/regtr).
67
+
68
+ Everything must be installed into the **same** conda enviromnet.
69
+
70
+ ## Data Preparation
71
+
72
+ ### ModelNet40
73
+
74
+ Download the dataset from [ModelNet40](http://modelnet.cs.princeton.edu/ModelNet40.zip) and extract it to:
75
+
76
+ ```
77
+ data/ModelNet40
78
+ ```
79
+
80
+ To save time, download the downsampled ModelNet40 test set from [ModelNet40_Downsampled](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/down_sampled_modelnet40.zip) and put it in:
81
+
82
+ ```
83
+ data/down_sampled_modelnet40
84
+ ```
85
+
86
+ ### Sioux-Cranfield
87
+
88
+ Download the dataset from [Sioux_Cranfiled](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/sioux_cranfield.zip) and put it in:
89
+
90
+ ```
91
+ data/sioux_cranfield
92
+ ```
93
+
94
+ ### Sioux-Scans
95
+
96
+ Download the dataset from [Sioux_Scans](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/sioux_scans.zip) and put it in:
97
+
98
+ ```
99
+ data/sioux_scans
100
+ ```
101
+
102
+ ### Fine-tune
103
+
104
+ Download the pickle files (.pkl) from [here](https://huggingface.co/datasets/YasiiKB/R3PM-Net/blob/main/simulators.zip) and put them in:
105
+
106
+ ```
107
+ data/simulators
108
+ ```
109
+ These pickle files are created from a subset of the Sioux-Cranfield containing the "teeth", "cube", "lime" and "lego" CAD models. There are 320 point cloud pairs, with 80-20 train-test split.
110
+
111
+ Optionally, to create your own datasets, use the scripts in `dataloader`, refering to the README file in that directory.
112
+
113
+ ## Pre-trained Models
114
+
115
+ Please download the pretrained model of each method from their repo (links provided above) and follow their instructions as to where to put them.
116
+
117
+ We use RPMNet's pre-trained model (*clean-trained*) for our Zero-shot version. Download it from [here](https://github.com/vinits5/learning3d/tree/master/pretrained/exp_rpmnet/models) and put it in:
118
+
119
+ ```
120
+ checkpoints/
121
+ ```
122
+
123
+ *Note:* You need to fine-tune the model yourself (see bleow) to get the fine-tuned weights which then you can put in the same directory.
124
+
125
+ ## Folder Structure
126
+
127
+ ```text
128
+ r3pm_net/
129
+ ├── assets/
130
+ ├── config/
131
+ │ ├── default.yaml # Training defaults
132
+ │ └── eval.yaml # Paths for evaluation scripts
133
+ ├── checkpoints/ # Pre-trained models' weights
134
+ ├── data/
135
+ │ ├── down_sampled_modelnet40/
136
+ │ ├── ModelNet40/
137
+ │ ├── sioux_cranfield/
138
+ │ └── sioux_scans/
139
+ ├── dataloader/ # Dataset dict generation & loaders
140
+ ├── logs/ # Experiment logs
141
+ ├── r3pm_net/ # Core package (model, feature extractor, config)
142
+ ├── scripts/ # SLURM/Bash and evaluation scripts
143
+ │ ├── eval_modelnet40.py
144
+ │ ├── eval_sioux_cranfield.py
145
+ │ ├── eval_sioux_scans.py
146
+ │ ├── modelnet40.sh
147
+ │ ├── sioux_cranfield.sh
148
+ │ └── sioux_scans.sh
149
+ ├── src/
150
+ │ └── train.py # Training
151
+ ├── thirdparty/learning3d/ # learning3d (RPMNet, losses, ops, …)
152
+ ├── tools/ # Registration eval, metrics, visualization
153
+ ├── environment.yml
154
+ ├── pyproject.toml
155
+ └── README.md
156
+ ```
157
+
158
+ ## Train
159
+
160
+ To train the model using `data/simulators` or your own dataset run:
161
+ ```bash
162
+ python src/train.py
163
+ ```
164
+
165
+ ## Evaluation
166
+
167
+ Scripts are provided in `scripts/` to reproduce results.
168
+
169
+ **ModelNet40**
170
+
171
+ ```bash
172
+ bash scripts/modelnet40.sh
173
+ ```
174
+
175
+ **Sioux-Cranfield**
176
+
177
+ ```bash
178
+ bash scripts/sioux_cranfield.sh
179
+ ```
180
+
181
+ **Sioux-Scans**
182
+ This evaluates the proposed hybrid Coarse-to-Fine Registration approach.
183
+
184
+ ```bash
185
+ bash scripts/sioux_scans.sh
186
+ ```
187
+
188
+ ### Manual Execution
189
+
190
+ For example for evaluation on `Sioux-Cranfield`, run:
191
+
192
+ ```bash
193
+ python scripts/eval_sioux_cranfield.py
194
+ ```
195
+
196
+ ## Results
197
+ *IMPORTANT NOTE: Unfortunately, we cannot release the feature-extraction model and the fine-tuned weights. Therefore, to re-poduce these results you need to implement the feature extractor (based on the paper) and fine-tune it with the provided data.*
198
+
199
+ ### ModelNet40
200
+
201
+
202
+ | Method | RRE [°] ↓ | RTE [cm] ↓ | CD [cm] ↓ | Fitness ↑ | In. RMSE [cm] ↓ | Time [s] ↓ |
203
+ | ------------------- | ----------------- | ----------------- | ----------------- | ----------------- | ------------------ | ----------------- |
204
+ | RPMNet | 30.898 | **0.002** | 0.153 | *0.998* | 0.094 | *0.021* |
205
+ | Predator | 7.262 | 0.028 | *0.045* | **1.000** | *0.026* | 0.071 |
206
+ | GeoTransformer | 50.357 | 0.215 | 0.255 | 0.921 | 0.101 | 0.065 |
207
+ | RegTR | **1.712** | *0.007* | **0.017** | **1.000** | **0.009** | 0.045 |
208
+ | LoGDesc | 42.762 | 0.158 | 0.183 | 0.978 | 0.097 | 0.075 |
209
+ | **R3PM-Net (ours)** | *5.198* | 0.010 | 0.052 | **1.000** | 0.029 | **0.007** |
210
+
211
+
212
+ > **Notes:** **Best** results are in bold; *Second-best* results are underlined.
213
+
214
+ ### Sioux-Cranfield
215
+
216
+
217
+ | Method | RRE [°] ↓ | RTE [cm] ↓ | CD [cm] ↓ | Fitness ↑ | In. RMSE [cm] ↓ | Time [s] ↓ |
218
+ | ------------------- | ----------------- | ----------------- | ----------------- | ----------------- | ------------------ | ----------------- |
219
+ | RPMNet | 32.217 | **0.002** | 0.160 | *0.997* | 0.098 | 0.021 |
220
+ | Predator | 16.448 | 0.044 | 0.072 | **1.000** | 0.042 | 0.071 |
221
+ | GeoTrans. | 45.582 | 0.183 | 0.297 | 0.906 | 0.111 | 0.065 |
222
+ | RegTR | **1.311** | *0.004* | **0.023** | **1.000** | **0.012** | 0.045 |
223
+ | LoGDesc | 121.224 | 0.773 | 0.692 | 0.718 | 0.224 | 0.075 |
224
+ | **R3PM-Net (ours)** | *5.451* | 0.006 | *0.054* | **1.000** | *0.030* | **0.006** |
225
+
226
+
227
+ ### Sioux-Scans
228
+ <p align="center"> <img src="assets/success_cases.png" width="85%"> </p>
229
+
230
+ <p align="left"><i>Figure 4. Qualitative registration results of R3PM-Net on real-world event-camera data. It successfully aligns the "teeth" and "cube" models. The fine-tuned version also solves the "lime" and "house".</i></p>
231
+
232
+ ## Acknowledgement
233
+
234
+ We adapted some codes from some awesome repositories including [Learning3D](https://github.com/vinits5/learning3d) and [RPMNet](https://github.com/yewzijian/RPMNet). Thanks for making the codes publicly available.
235
+
236
+ ## Citation
237
+
238
+ If you find this repository useful, please consider citing:
239
+
240
+ ```bibtex
241
+ @misc{kashefbahrami2026r3pmnetrealtimerobustrealworld,
242
+ title={R3PM-Net: Real-time, Robust, Real-world Point Matching Network},
243
+ author={Yasaman Kashefbahrami and Erkut Akdag and Panagiotis Meletis and Evgeniya Balmashnova and Dip Goswami and Egor Bondarau},
244
+ year={2026},
245
+ eprint={2604.05060},
246
+ archivePrefix={arXiv},
247
+ primaryClass={cs.CV},
248
+ url={https://arxiv.org/abs/2604.05060},
249
+ }
250
+ ```
251
+
assets/r3pmnet_overview.png ADDED

Git LFS Details

  • SHA256: 40dac0c28fa7d4f1213ff2ab93493bfe30afc9ccc97097faa53e5e7fcdd2d057
  • Pointer size: 131 Bytes
  • Size of remote file: 279 kB
assets/sioux_cranfield.png ADDED

Git LFS Details

  • SHA256: 8b68d27b5ac14b43a2b814d2ad8d599786d5c4187ce90bf498fb9a270011f180
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
assets/sioux_scans.png ADDED

Git LFS Details

  • SHA256: fbd242ea53dfb82b967be8d3f1f9aa34090b72696034e0d7284760f75f82ad62
  • Pointer size: 131 Bytes
  • Size of remote file: 429 kB
assets/success_cases.png ADDED

Git LFS Details

  • SHA256: bd2cf529c20d11f3b04b345a69746635517ea37d1964650fc47b13394f8448a1
  • Pointer size: 131 Bytes
  • Size of remote file: 679 kB
assets/teaser.png ADDED

Git LFS Details

  • SHA256: a20019ddf1e712d331fd56d251111b9b37f062c908409134aaf621b6758cef58
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
config/default.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Default training and data paths for R3PM-Net
2
+
3
+ exp_name: exp_r3pmnet
4
+ eval: false
5
+ save_dir: ""
6
+
7
+ fine_tune_feature_extractor: tune
8
+ transfer_weights: "" # Optional: leave empty to skip loading
9
+ emb_dims: 1024
10
+ symfn: max
11
+
12
+ seed: 1234
13
+ workers: 4
14
+ batch_size: 5
15
+ epochs: 2
16
+ start_epoch: 0
17
+ optimizer: Adam
18
+ resume: ""
19
+ pretrained: ""
20
+ device: cuda:0
21
+
22
+ # Pickled Registration Dataset dicts (keys: template, source, transformation)
23
+ train_dict_path: data/simulators/data_dict_train.pkl
24
+ test_dict_path: data/simulators/data_dict_test.pkl
config/eval.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Paths for scripts/evaluation (loaded by scripts/*.py).
2
+
3
+ data_root: data
4
+ pretrained_rpmnet_dir: checkpoints
5
+
6
+ modelnet40:
7
+ dataset_path: data/ModelNet40
8
+ cache_dir: data/down_sampled_modelnet40
9
+
10
+ sioux:
11
+ base_dir: data
12
+
13
+ methods:
14
+ geotransformer:
15
+ root: GeoTransformer
16
+ exp_subdir: GeoTransformer/experiments/geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn
17
+ weights_path: GeoTransformer/weights/geotransformer-modelnet.pth.tar
18
+ predator:
19
+ root: OverlapPredator
20
+ config_path: OverlapPredator/configs/test/modelnet.yaml
21
+ weights_path: null
22
+ logdesc:
23
+ root: LoGDesc
24
+ weights_path: LoGDesc/pre-trained/best_model.pth
25
+ regtr:
26
+ root: RegTR
27
+ ckpt_path: RegTR/trained_models/modelnet/ckpt/model-best.pth
28
+ config_path: RegTR/trained_models/modelnet/config.yaml
dataloader/README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloaders
2
+
3
+ This directory contains scripts to generate Simulated datasets for training and testing.
4
+ It uses functionalities from `tools` folder to:
5
+
6
+ `generate_dataset`
7
+
8
+ - load and downsample data
9
+ - compute normals if needed
10
+ - apply random transformations
11
+ - add augmentations (noise, outliers and occlusion)
12
+ - save point clouds
13
+
14
+ **Note: Two input point clouds must have same length.**
15
+
16
+ `generate_dataset_dict`
17
+
18
+ - save generated dataset in dictionaries suitable to train models (following Learning3d requirments)
19
+ - checks dimensions (to meet Learning3d requirments)
20
+
21
+ `combine_dataset_dict`
22
+
23
+ - shuffle and combine all generated dictionaries into one
24
+ - split train and test sets
25
+
26
+ ## How to generate datasets?
27
+
28
+ Modify the `args.txt` file to contain the correct paths and other specifications e.g. downsampling rate, noise level, etc. Other default arguments in `data_dict_generator.py` can also be changed.
29
+
30
+ 1. Generate transformed target point clouds + GT transforms
31
+ Change `--action in dataloader/args.txt` to `generate_dataset` and run:
32
+ ```
33
+ python dataloader/data_dict_generator.py @dataloader/args.txt
34
+ ```
35
+
36
+ 2. Generate train/test .pkl dicts
37
+ Change `--action in dataloader/args.txt` to `generate_dataset_dict`, then run the above script again.
38
+
39
+ 3. Combine multiple object dicts
40
+ Set `--action combine_dataset_dict` and run again to get train and test `dict.pkl` files.
41
+
42
+ ### Manual Run (without args.txt)
43
+ Optinally you can manually run:
44
+ ```
45
+ python dataloader/data_dict_generator.py \
46
+ --pcdPath /path/to/source_scan.pcd \
47
+ --cadPath /path/to/object.stl \
48
+ --name teeth \
49
+ --action generate_dataset \
50
+ --every_k_points 100 \
51
+ --num_transformation 50 \
52
+ --angles 0 90 180 \
53
+ --translation_range -1 1 \
54
+ --index 0 \
55
+ --noise_level 0 \
56
+ --outlier_level 0 \
57
+ --outlier_bounds -0.05 0.05 \
58
+ --occ_level 0 \
59
+ --save_path data/simulators
60
+ ```
61
+ then:
62
+ ```
63
+ python dataloader/data_dict_generator.py \
64
+ --pcdPath /path/to/source_scan.pcd \
65
+ --cadPath /path/to/object.stl \
66
+ --name teeth \
67
+ --action generate_dataset_dict \
68
+ --dataset_size 50 \
69
+ --index 0 \
70
+ --save_path data/simulators
71
+ ```
dataloader/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Dataset loaders and generators
dataloader/args.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --pcdPath data/sioux_scans/teeth_clean.ply
2
+ --cadPath data/sioux_cranfield/teeth.stl
3
+ --action combine_dataset_dict
4
+ --name teeth
5
+ --every_k_points 100
6
+ --num_transformation 50
7
+ --angles 0 90 180
8
+ --translation_range -1 1
9
+ --dataset_size 50
10
+ --index 0
11
+ --noise_level 0
12
+ --outlier_level 0
13
+ --outlier_bounds -0.05 0.05
14
+ --occ_level 0
dataloader/data_dict_generator.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import os
5
+ import shlex
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+
11
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
12
+ if str(_REPO_ROOT) not in sys.path:
13
+ sys.path.insert(0, str(_REPO_ROOT))
14
+
15
+ from tools import data
16
+ from dataloader.dataset_generator import combine_dataset_dict, generate_dataset, generate_dataset_dict
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+
21
+ def main():
22
+ # Set up the argument parser
23
+ parser = argparse.ArgumentParser(description='Automate dataset generation and processing.')
24
+
25
+ # Define arguments (change these as needed)
26
+ parser.add_argument('--pcdPath', type=str, required=True, help='Path to the PCD file')
27
+ parser.add_argument('--cadPath', type=str, required=True, help='Path to the CAD file')
28
+ parser.add_argument('--action', type=str, choices=['generate_dataset', 'generate_dataset_dict', 'combine_dataset_dict'], required=True, help='Action to perform')
29
+ parser.add_argument('--compute_normals', action='store_true', help='Flag to compute normals')
30
+ parser.add_argument('--every_k_points', type=int, default=1, help='Sampling rate for points')
31
+ parser.add_argument('--save', action='store_true', help='Flag to save the generated dataset')
32
+ parser.add_argument(
33
+ '--save_path',
34
+ type=str,
35
+ default='data/simulators',
36
+ help='Directory to save generated datasets (relative to repo root if not absolute)',
37
+ )
38
+ parser.add_argument('--name', type=str, required=True, help='Name identifier for the dataset (e.g., teeth, cube, etc.)')
39
+
40
+ # Additional parameters for dataset generation (change these as needed)
41
+ parser.add_argument('--num_transformation', type=int, default=50, help='Number of transformations')
42
+ parser.add_argument('--angles', type=int, nargs='+', default=list(range(0, 360, 10)), help='Rotation angles')
43
+ parser.add_argument('--translation_range', type=float, nargs=2, default=(-1, 1), help='Translation range')
44
+ parser.add_argument('--dataset_size', type=int, default=400, help='Size of the dataset to generate')
45
+ parser.add_argument('--index', type=int, default=0, help='Index for dataset generation')
46
+ parser.add_argument('--noise_level', type=float, default=0, help='Noise level')
47
+ parser.add_argument('--outlier_level', type=float, default=0, help='Outlier level')
48
+ parser.add_argument('--outlier_bounds', type=float, nargs=2, default=(-10, 10), help='Outlier bounds')
49
+ parser.add_argument('--occ_level', type=float, default=0, help='Occlusion level')
50
+
51
+ # Parse the arguments
52
+
53
+ # Check if an argument file is being used
54
+ if sys.argv[1].startswith('@'):
55
+ args_file = sys.argv[1][1:] # Strip the '@' from the filename
56
+ with open(args_file, 'r') as file:
57
+ # Read and split arguments from the file
58
+ args = parser.parse_args(shlex.split(file.read()))
59
+ else:
60
+ args = parser.parse_args()
61
+
62
+ # Print out the arguments to verify
63
+ print(vars(args))
64
+
65
+ # Load the data
66
+ np.random.seed(42)
67
+ if args.compute_normals:
68
+ _, cad, _, cad_normals = data.load_data(args.pcdPath, args.cadPath, every_k_points=args.every_k_points, same_length=True, compute_normals=True)
69
+ suffix = '_with_normals'
70
+ else:
71
+ _, cad = data.load_data(args.pcdPath, args.cadPath, every_k_points=args.every_k_points, same_length=True)
72
+ cad_normals = None
73
+ suffix = ''
74
+ source = copy.deepcopy(cad)
75
+
76
+ rp = Path(args.save_path)
77
+ if not rp.is_absolute():
78
+ rp = _REPO_ROOT / args.save_path
79
+ ROOT_DIR = str(rp.resolve())
80
+ if not ROOT_DIR.endswith(os.sep):
81
+ ROOT_DIR += os.sep
82
+
83
+ # Perform the selected action
84
+ if args.action == 'generate_dataset':
85
+ logging.info('Generating dataset...')
86
+ generate_dataset(source, args.pcdPath, args.cadPath, args.num_transformation, args.angles, args.translation_range, args.index, args.noise_level, args.outlier_level, args.outlier_bounds, args.occ_level, save_dir=ROOT_DIR)
87
+
88
+ elif args.action == 'generate_dataset_dict':
89
+ logging.info('Generating dataset dictionary...')
90
+ output_train_file = f'{ROOT_DIR}data_dict_train_{args.name}{suffix}.pkl'
91
+ output_test_file = f'{ROOT_DIR}data_dict_test_{args.name}{suffix}.pkl'
92
+ generate_dataset_dict(source, args.dataset_size, args.index, output_train_file, output_test_file, cad_normals)
93
+
94
+ elif args.action == 'combine_dataset_dict':
95
+ logging.info('Combining dataset dictionaries...')
96
+ train_files = [
97
+ f'{ROOT_DIR}data_dict_train_teeth{suffix}.pkl'
98
+ # f'{ROOT_DIR}data_dict_train_elephant{suffix}.pkl',
99
+ # f'{ROOT_DIR}data_dict_train_house{suffix}.pkl',
100
+ # f'{ROOT_DIR}data_dict_train_shoe{suffix}.pkl'
101
+ ]
102
+
103
+ test_files = [
104
+ f'{ROOT_DIR}data_dict_test_teeth{suffix}.pkl'
105
+ # f'{ROOT_DIR}data_dict_test_elephant{suffix}.pkl',
106
+ # f'{ROOT_DIR}data_dict_test_house{suffix}.pkl',
107
+ # f'{ROOT_DIR}data_dict_test_shoe{suffix}.pkl'
108
+ ]
109
+
110
+ output_train_file = f'{ROOT_DIR}data_dict_train_{suffix}.pkl'
111
+ output_test_file = f'{ROOT_DIR}data_dict_test_{suffix}.pkl'
112
+
113
+ combine_dataset_dict(train_files, test_files, output_train_file, output_test_file)
114
+
115
+ else:
116
+ logging.warning('No valid action selected.')
117
+
118
+ if __name__ == '__main__':
119
+ main()
dataloader/dataset_generator.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import pickle
4
+ import random
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import open3d as o3
10
+
11
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
12
+ if str(_REPO_ROOT) not in sys.path:
13
+ sys.path.insert(0, str(_REPO_ROOT))
14
+
15
+ from tools import augmentation, data, transformations
16
+
17
+ _SIM_DATA = _REPO_ROOT / "data" / "simulators"
18
+ '''
19
+ This module provides functions to generate a dataset of point clouds with random transformations, with options for noise, outliers, and occlusions.
20
+ It also includes functions to check the shape of the data and to generate a data dictionary for training and testing,
21
+ and a function to combine multiple dataset dictionaries.
22
+ '''
23
+
24
+ def generate_dataset(pcd, pcdPath, cadPath, num_transformation, angles, translation_range, index, noise_level = 0, outlier_level = 0, outlier_bounds = (-10, 10), occ_level = 0, save_dir=None):
25
+ '''
26
+ A function to generate a dataset of point clouds with random transformations.
27
+
28
+ Args:
29
+ pcd (open3d.geometry.PointCloud): The source point cloud
30
+ pcdPath (str): The path to the source point cloud
31
+ cadPath (str): The path to the target point cloud
32
+ num_transformation (int): The number of transformations to generate
33
+ angles (numpy.ndarray): The range of angles for the random transformations
34
+ translation_range (tuple): The range of translations for the random transformations
35
+ index (int): The index to start saving the generated dataset
36
+ noise_level (float): The level of noise to add to the point clouds
37
+ outlier_level (float): The level of outliers to add to the point clouds
38
+ occ_level (float): The level of occlusions to add to the point clouds
39
+ save (bool): A flag to save the generated dataset
40
+
41
+ Returns:
42
+ None
43
+ '''
44
+ np.random.seed(42)
45
+ target_list = []
46
+ gt_transformation_list = []
47
+
48
+ for i in range(num_transformation):
49
+ # Generate random gt transformation
50
+ x_angle= np.random.uniform(angles[0], angles[-1], size=1)
51
+ y_angle= np.random.uniform(angles[0], angles[-1], size=1)
52
+ z_angle= np.random.uniform(angles[0], angles[-1], size=1)
53
+ gt_transformation = transformations.create_transformation(x_angle, y_angle, z_angle, translation_range)
54
+
55
+ target = copy.deepcopy(pcd)
56
+ target.transform(gt_transformation)
57
+
58
+ if noise_level != 0:
59
+ target = augmentation.apply_noise(target, noise_level)
60
+ print('Noise applied')
61
+
62
+ if outlier_level != 0 or occ_level != 0:
63
+ _, another_cad = data.load_data(pcdPath, cadPath, every_k_points=1)
64
+ target = copy.deepcopy(another_cad).transform(gt_transformation)
65
+ if occ_level != 0:
66
+ target, _ = augmentation.apply_occlusion(target, occ_level)
67
+ print('Occlusion applied')
68
+ if outlier_level != 0:
69
+ target = augmentation.add_outliers(target, outlier_level, outlier_lowerbound=outlier_bounds[0], outlier_upperbound=outlier_bounds[1])
70
+ print('Outliers applied')
71
+
72
+ # randomly take points away from target to get to same length as source
73
+ if len(target.points) >= len(pcd.points):
74
+ np.random.seed(42)
75
+ target_points = np.asarray(target.points)
76
+ indices = np.random.choice(len(target_points), 1441, replace=False) # change len(source.points) to a specific num if you want to have a fixed number of points
77
+ sampled_points = target_points[indices]
78
+ target.points = o3.utility.Vector3dVector(sampled_points)
79
+ else:
80
+ print('Target has fewer points than source and can\'t be downsampled to the same length.')
81
+
82
+ print(f'size of source and target: {len(pcd.points)}, {len(target.points)}')
83
+ target_list.append(target)
84
+ gt_transformation_list.append(gt_transformation)
85
+
86
+ # Save the generated dataset
87
+ if save_dir is not None:
88
+ if not os.path.exists(save_dir):
89
+ os.makedirs(save_dir)
90
+
91
+ for i, (target, transformation) in enumerate(zip(target_list, gt_transformation_list)):
92
+ target_path = os.path.join(save_dir, f"target_{i+index}.pcd")
93
+ transformation_path = os.path.join(save_dir, f"transformation_{i+index}.npy")
94
+ o3.io.write_point_cloud(target_path, target)
95
+ np.save(transformation_path, transformation)
96
+
97
+ def check_shape(data, expected_shape_3d, expected_shape_6d):
98
+ return data.shape == expected_shape_3d or data.shape == expected_shape_6d
99
+
100
+ def generate_dataset_dict(source, dataset_size, index, output_train_file_path, output_test_file_path, source_normals = None):
101
+ '''
102
+ This function shuffles the dataset and generates a data_dict for the training and testing data following the pattern acceptable to Learning3D.
103
+
104
+ Args:
105
+ source (open3d.geometry.PointCloud): The source point cloud
106
+ dataset_size (int): The size of the dataset
107
+
108
+ Returns:
109
+ None
110
+ '''
111
+ np.random.seed(42)
112
+ transformed_pcds = []
113
+ gt_transformations = []
114
+
115
+ # Load the transformed point clouds and ground truth transformations
116
+ for i in range(index,index+dataset_size):
117
+ transformed_pcd = o3.io.read_point_cloud(str(_SIM_DATA / f"target_{i}.pcd"))
118
+ gt_transformation = np.load(str(_SIM_DATA / f"transformation_{i}.npy"))
119
+
120
+ if source_normals is not None: # we also need target normals
121
+ M = np.linalg.inv(gt_transformation).T
122
+ target_normals = np.dot(source_normals, M[:3,:3]) # transformed_normals = normals * (transformation)^-1.T
123
+ transformed_points = np.concatenate((np.asarray(transformed_pcd.points), target_normals), axis=1)
124
+ else:
125
+ transformed_points = np.asarray(transformed_pcd.points).astype(np.float32)
126
+
127
+ transformed_pcds.append(transformed_points)
128
+ gt_transformations.append(gt_transformation)
129
+
130
+ # Shuffle the transformed point clouds and ground truth transformations in the same way
131
+ temp = list(zip(transformed_pcds, gt_transformations))
132
+ random.shuffle(temp)
133
+ transformed_pcds, gt_transformations = zip(*temp)
134
+
135
+ # Convert lists to numpy arrays
136
+ transformed_pcds_np = np.array(transformed_pcds)
137
+ gt_transformations_np = np.array(gt_transformations)
138
+
139
+ if source_normals is not None:
140
+ source = np.concatenate((np.asarray(source.points), source_normals), axis=1)
141
+ else:
142
+ source = np.asarray(source.points).astype(np.float32)
143
+
144
+ data_dict = {
145
+ 'template': np.tile(source, (dataset_size, 1, 1)),
146
+ 'source': transformed_pcds_np,
147
+ 'transformation': gt_transformations_np
148
+ }
149
+
150
+ # Split the data_dict into training and testing data_dict
151
+ train_size = int(0.8 * dataset_size)
152
+ test_size = dataset_size - train_size
153
+ num_points = len(source)
154
+
155
+ data_dict_train = {}
156
+ data_dict_test = {}
157
+ for key in data_dict.keys():
158
+ data_dict_train[key] = data_dict[key][0:train_size]
159
+ data_dict_test[key] = data_dict[key][train_size:]
160
+
161
+ assert set(data_dict_train.keys()) == {'template', 'source', 'transformation'}
162
+ assert set(data_dict_test.keys()) == {'template', 'source', 'transformation'}
163
+
164
+ expected_shape_3d_train = (train_size, num_points, 3)
165
+ expected_shape_6d_train = (train_size, num_points, 6)
166
+
167
+ assert check_shape(data_dict_train['template'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {data_dict_train['template'].shape}"
168
+ assert check_shape(data_dict_train['source'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {data_dict_train['source'].shape}"
169
+ assert data_dict_train['transformation'].shape == (train_size, 4, 4), f"Expected shape: {(train_size, 4, 4)}, but got {data_dict_train['transformation'].shape}"
170
+
171
+ expected_shape_3d_test = (test_size, num_points, 3)
172
+ expected_shape_6d_test = (test_size, num_points, 6)
173
+
174
+ assert check_shape(data_dict_test['template'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {data_dict_test['template'].shape}"
175
+ assert check_shape(data_dict_test['source'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {data_dict_test['source'].shape}"
176
+ assert data_dict_test['transformation'].shape == (test_size, 4, 4), f"Expected shape: {(test_size, 4, 4)}, but got {data_dict_test['transformation'].shape}"
177
+
178
+ with open(output_train_file_path, 'wb') as f:
179
+ pickle.dump(data_dict_train, f)
180
+ print(f"train_dict saved to {output_train_file_path}")
181
+
182
+ with open(output_test_file_path, 'wb') as f:
183
+ pickle.dump(data_dict_test, f)
184
+ print(f"test_dict saved to {output_test_file_path}")
185
+
186
+
187
+ def combine_dataset_dict(train_files, test_files, output_train_file_path, output_test_file_path):
188
+ '''
189
+ Combine and shuffle dictionaries from multiple files.
190
+
191
+ Args:
192
+ train_files (list of str): List of file paths to training dictionaries.
193
+ test_files (list of str): List of file paths to testing dictionaries.
194
+ output_train_file (str): Output file path for the combined training dictionary.
195
+ output_test_file (str): Output file path for the combined testing dictionary.
196
+ '''
197
+
198
+ # Load the dictionaries from the .pkl files
199
+ train_dicts = [pickle.load(open(file, 'rb')) for file in train_files]
200
+ test_dicts = [pickle.load(open(file, 'rb')) for file in test_files]
201
+
202
+ # Combine the dictionaries
203
+ combined_train_dict = {}
204
+ combined_test_dict = {}
205
+
206
+ for key in train_dicts[0].keys():
207
+ combined_train_dict[key] = np.concatenate([d[key] for d in train_dicts], axis=0)
208
+ combined_test_dict[key] = np.concatenate([d[key] for d in test_dicts], axis=0)
209
+
210
+ # Shuffle
211
+ train_combined_list = list(zip(combined_train_dict['template'], combined_train_dict['source'], combined_train_dict['transformation']))
212
+ test_combined_list = list(zip(combined_test_dict['template'], combined_test_dict['source'], combined_test_dict['transformation']))
213
+
214
+ random.shuffle(train_combined_list)
215
+ random.shuffle(test_combined_list)
216
+
217
+ combined_train_dict['template'], combined_train_dict['source'], combined_train_dict['transformation'] = zip(*train_combined_list)
218
+ combined_test_dict['template'], combined_test_dict['source'], combined_test_dict['transformation'] = zip(*test_combined_list)
219
+
220
+ # Convert back to numpy arrays
221
+ combined_train_dict['template'] = np.array(combined_train_dict['template'])
222
+ combined_train_dict['source'] = np.array(combined_train_dict['source'])
223
+ combined_train_dict['transformation'] = np.array(combined_train_dict['transformation'])
224
+
225
+ combined_test_dict['template'] = np.array(combined_test_dict['template'])
226
+ combined_test_dict['source'] = np.array(combined_test_dict['source'])
227
+ combined_test_dict['transformation'] = np.array(combined_test_dict['transformation'])
228
+
229
+ # Checks
230
+ train_size = len(combined_train_dict['source'])
231
+ test_size = len(combined_test_dict['source'])
232
+ num_points = combined_train_dict['source'].shape[1]
233
+
234
+ assert set(combined_train_dict.keys()) == {'template', 'source', 'transformation'}
235
+ assert set(combined_test_dict.keys()) == {'template', 'source', 'transformation'}
236
+
237
+ expected_shape_3d_train = (train_size, num_points, 3)
238
+ expected_shape_6d_train = (train_size, num_points, 6)
239
+
240
+ assert check_shape(combined_train_dict['template'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {combined_train_dict['template'].shape}"
241
+ assert check_shape(combined_train_dict['source'], expected_shape_3d_train, expected_shape_6d_train), f"Expected shape: {expected_shape_3d_train} or {expected_shape_6d_train}, but got {combined_train_dict['source'].shape}"
242
+ assert combined_train_dict['transformation'].shape == (train_size, 4, 4), f"Expected shape: {(train_size, 4, 4)}, but got {combined_train_dict['transformation'].shape}"
243
+
244
+ expected_shape_3d_test = (test_size, num_points, 3)
245
+ expected_shape_6d_test = (test_size, num_points, 6)
246
+
247
+ assert check_shape(combined_test_dict['template'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {combined_test_dict['template'].shape}"
248
+ assert check_shape(combined_test_dict['source'], expected_shape_3d_test, expected_shape_6d_test), f"Expected shape: {expected_shape_3d_test} or {expected_shape_6d_test}, but got {combined_test_dict['source'].shape}"
249
+ assert combined_test_dict['transformation'].shape == (test_size, 4, 4), f"Expected shape: {(test_size, 4, 4)}, but got {combined_test_dict['transformation'].shape}"
250
+
251
+ # Save the dictionaries
252
+ with open(output_train_file_path, 'wb') as f:
253
+ pickle.dump(combined_train_dict, f)
254
+ print(f"combined_train_dict saved to {output_train_file_path}")
255
+
256
+ with open(output_test_file_path, 'wb') as f:
257
+ pickle.dump(combined_test_dict, f)
258
+ print(f"combined_test_dict saved to {output_train_file_path}")
dataloader/user_data.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+
5
+ class ClassificationData:
6
+ def __init__(self, data_dict):
7
+ self.data_dict = data_dict
8
+ self.pcs = self.find_attribute('pcs')
9
+ self.labels = self.find_attribute('labels')
10
+ self.check_data()
11
+
12
+ def find_attribute(self, attribute):
13
+ try:
14
+ attribute_data = self.data_dict[attribute]
15
+ except:
16
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
17
+ return attribute_data
18
+
19
+ def check_data(self):
20
+ assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
21
+ assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)
22
+
23
+ if len(self.pcs.shape)==2: self.pcs = self.pcs.reshape(1, -1, 3)
24
+ if len(self.labels.shape) == 1: self.labels = self.labels.reshape(1, -1)
25
+
26
+ assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"
27
+
28
+
29
+ def __len__(self):
30
+ return self.pcs.shape[0]
31
+
32
+ def __getitem__(self, index):
33
+ return torch.tensor(self.pcs[index]).float(), torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
34
+
35
+
36
+ class RegistrationData:
37
+ def __init__(self, data_dict):
38
+ self.data_dict = data_dict
39
+ self.template = self.find_attribute('template')
40
+ self.source = self.find_attribute('source')
41
+ self.transformation = self.find_attribute('transformation')
42
+ self.check_data()
43
+
44
+ # def find_attribute(self, attribute):
45
+ # try:
46
+ # attribute_data = self.data[attribute]
47
+ # except:
48
+ # print("Given data directory has no key attribute \"{}\"".format(attribute))
49
+ # return attribute_data
50
+
51
+ def find_attribute(self, attribute):
52
+ attribute_data = None
53
+ if attribute in self.data_dict:
54
+ attribute_data = self.data_dict[attribute]
55
+ else:
56
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
57
+ return attribute_data
58
+
59
+ def check_data(self):
60
+ assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
61
+ assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
62
+ assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)
63
+
64
+ if len(self.template.shape)==2: self.template = self.template.reshape(1, -1, 3)
65
+ if len(self.source.shape)==2: self.source = self.source.reshape(1, -1, 3)
66
+ if len(self.transformation.shape) == 2: self.transformation = self.transformation.reshape(1, 4, 4)
67
+
68
+ assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
69
+ assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"
70
+
71
+ def __len__(self):
72
+ return self.template.shape[0]
73
+
74
+ def __getitem__(self, index):
75
+ return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()
76
+
77
+
78
+ class FlowData:
79
+ def __init__(self, data_dict):
80
+ self.data_dict = data_dict
81
+ self.frame1 = self.find_attribute('frame1')
82
+ self.frame2 = self.find_attribute('frame2')
83
+ self.flow = self.find_attribute('flow')
84
+ self.check_data()
85
+
86
+ def find_attribute(self, attribute):
87
+ try:
88
+ attribute_data = self.data[attribute]
89
+ except:
90
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
91
+ return attribute_data
92
+
93
+ def check_data(self):
94
+ assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
95
+ assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
96
+ assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)
97
+
98
+ if len(self.frame1.shape)==2: self.frame1 = self.frame1.reshape(1, -1, 3)
99
+ if len(self.frame2.shape)==2: self.frame2 = self.frame2.reshape(1, -1, 3)
100
+ if len(self.flow.shape) == 2: self.flow = self.flow.reshape(1, -1, 3)
101
+
102
+ assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
103
+ assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"
104
+
105
+ def __len__(self):
106
+ return self.frame1.shape[0]
107
+
108
+ def __getitem__(self, index):
109
+ return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()
110
+
111
+
112
+ class UserData:
113
+ def __init__(self, application, data_dict):
114
+ self.application = application
115
+
116
+ if self.application == 'classification':
117
+ self.data_class = ClassificationData(data_dict)
118
+ elif self.application == 'registration':
119
+ self.data_class = RegistrationData(data_dict)
120
+ elif self.application == 'flow_estimation':
121
+ self.data_class = FlowData(data_dict)
122
+
123
+ def __len__(self):
124
+ return len(self.data_class)
125
+
126
+ def __getitem__(self, index):
127
+ return self.data_class[index]
environment.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: r3pm_net
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ dependencies:
6
+ - python=3.9
7
+ - pip
8
+ - open3d
9
+ - pytorch
10
+ - hatchling
11
+ - ipykernel
12
+ - pip:
13
+ - tabulate
14
+ - pyyaml
15
+ - -e .
16
+
pyproject.toml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "r3pm_net"
7
+ version = "0.1.0"
8
+ description = "R3PM-Net point cloud registration"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ dependencies = [
12
+ "numpy",
13
+ "torch",
14
+ "tensorboardX",
15
+ "tqdm",
16
+ "pyyaml",
17
+ "open3d",
18
+ "tabulate",
19
+ ]
20
+
21
+ [tool.hatch.build.targets.wheel]
22
+ packages = ["r3pm_net", "tools", "dataloader", "thirdparty"]
23
+
24
+ [tool.jupytext]
25
+ formats = "ipynb,py:light"
r3pm_net/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """R3PM-Net: point cloud registration with PointNet features."""
2
+
3
+ from .model import R3PMNet
4
+
5
+ __all__ = ["R3PMNet"]
r3pm_net/config_loader.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load YAML training config and merge with argparse."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any, Mapping
9
+
10
+ import yaml
11
+
12
+ from r3pm_net.paths import REPO_ROOT
13
+
14
+
15
+ def _resolve_maybe_relative(path_str: str | None) -> str | None:
16
+ if path_str is None or path_str == "":
17
+ return path_str
18
+ p = Path(path_str)
19
+ if p.is_absolute():
20
+ return str(p)
21
+ return str(REPO_ROOT / p)
22
+
23
+
24
+ def load_yaml_config(path: str | Path) -> dict[str, Any]:
25
+ with open(path, "r", encoding="utf-8") as f:
26
+ data = yaml.safe_load(f)
27
+ if data is None:
28
+ return {}
29
+ if not isinstance(data, Mapping):
30
+ raise ValueError(f"Config must be a mapping, got {type(data)}")
31
+ return dict(data)
32
+
33
+
34
+ def _extract_config_argv(argv: list[str], default_cfg: str) -> tuple[str, list[str]]:
35
+ """Return (config path for YAML, argv without --config ...)."""
36
+ path = default_cfg
37
+ out: list[str] = []
38
+ i = 0
39
+ while i < len(argv):
40
+ if argv[i] == "--config" and i + 1 < len(argv):
41
+ path = argv[i + 1]
42
+ i += 2
43
+ continue
44
+ if argv[i].startswith("--config="):
45
+ path = argv[i].split("=", 1)[1]
46
+ i += 1
47
+ continue
48
+ out.append(argv[i])
49
+ i += 1
50
+ return path, out
51
+
52
+
53
+ def parse_train_args(argv: list[str], build_parser) -> argparse.Namespace:
54
+ """Load YAML from --config (default: config/default.yaml), merge as argparse defaults, then parse CLI."""
55
+ default_cfg = str(REPO_ROOT / "config" / "default.yaml")
56
+ cfg_path, argv_rest = _extract_config_argv(list(argv), default_cfg)
57
+ cfg = load_yaml_config(cfg_path) if Path(cfg_path).is_file() else {}
58
+ parser = build_parser(cfg_path)
59
+ if cfg:
60
+ known = {
61
+ a.dest
62
+ for a in parser._actions
63
+ if getattr(a, "dest", None) and a.dest not in ("help", argparse.SUPPRESS)
64
+ }
65
+ filtered = {k: v for k, v in cfg.items() if k in known}
66
+ parser.set_defaults(**filtered)
67
+ return parser.parse_args(argv_rest)
68
+
69
+
70
+ def resolve_path_args(ns: Any, path_keys: tuple[str, ...]) -> None:
71
+ """Mutate namespace: resolve listed keys to absolute paths under REPO_ROOT when relative."""
72
+ for key in path_keys:
73
+ val = getattr(ns, key, None)
74
+ if isinstance(val, str) and val:
75
+ setattr(ns, key, _resolve_maybe_relative(val))
76
+
77
+
78
+ def load_eval_yaml() -> dict[str, Any]:
79
+ """Load ``config/eval.yaml`` if present; otherwise return an empty dict."""
80
+ path = REPO_ROOT / "config" / "eval.yaml"
81
+ if not path.is_file():
82
+ return {}
83
+ return load_yaml_config(path)
84
+
85
+
86
+ def get_pretrained_rpmnet_dir() -> str:
87
+ """Directory containing ``clean-trained.pth``, ``best_model_PointNet*.t7``, etc.
88
+
89
+ ``R3PM_NET_PRETRAINED_ROOT`` overrides ``pretrained_rpmnet_dir`` in ``config/eval.yaml``.
90
+ """
91
+ env = os.environ.get("R3PM_NET_PRETRAINED_ROOT")
92
+ if env:
93
+ return str(Path(env).expanduser().resolve())
94
+ cfg = load_eval_yaml()
95
+ rel = (cfg.get("pretrained_rpmnet_dir") or "checkpoints").strip()
96
+ if not rel:
97
+ rel = "checkpoints"
98
+ out = _resolve_maybe_relative(rel)
99
+ return out if out else str(REPO_ROOT / "checkpoints")
100
+
101
+
102
+ def get_sioux_data_root() -> str:
103
+ """Base data directory for Sioux scripts (``data`` / ``sioux_cranfield``, etc.)."""
104
+ cfg = load_eval_yaml()
105
+ sioux = cfg.get("sioux") or {}
106
+ base = sioux.get("base_dir") or cfg.get("data_root") or "data"
107
+ out = _resolve_maybe_relative(str(base).strip())
108
+ return out if out else str(REPO_ROOT / "data")
109
+
110
+
111
+ def get_modelnet40_paths() -> tuple[str, str]:
112
+ """Return ``(dataset_path, cache_dir)`` for ModelNet40 evaluation."""
113
+ cfg = load_eval_yaml()
114
+ m = cfg.get("modelnet40") or {}
115
+ ds = m.get("dataset_path", "data/ModelNet40")
116
+ cache = m.get("cache_dir", "data/down_sampled_modelnet40")
117
+ dsr = _resolve_maybe_relative(ds)
118
+ cr = _resolve_maybe_relative(cache)
119
+ return (
120
+ dsr if dsr else str(REPO_ROOT / "data" / "ModelNet40"),
121
+ cr if cr else str(REPO_ROOT / "data" / "down_sampled_modelnet40"),
122
+ )
123
+
124
+
125
+ def get_method_paths() -> dict[str, Any]:
126
+ """Return resolved path configuration for external registration methods."""
127
+ cfg = load_eval_yaml()
128
+ methods = cfg.get("methods") or {}
129
+ out: dict[str, Any] = {}
130
+ for method_name, method_cfg in methods.items():
131
+ if not isinstance(method_cfg, Mapping):
132
+ continue
133
+ method_out: dict[str, Any] = {}
134
+ for k, v in method_cfg.items():
135
+ if isinstance(v, str) and v.strip():
136
+ rv = _resolve_maybe_relative(v.strip())
137
+ method_out[k] = rv if rv else v
138
+ else:
139
+ method_out[k] = v
140
+ out[str(method_name)] = method_out
141
+ return out
142
+
143
+
144
+ def get_sioux_paths() -> dict[str, Any]:
145
+ """Return Sioux eval paths from config/eval.yaml with absolute paths."""
146
+ cfg = load_eval_yaml()
147
+ sioux = cfg.get("sioux") or {}
148
+ out: dict[str, Any] = {}
149
+ for k, v in sioux.items():
150
+ if isinstance(v, str) and v.strip():
151
+ rv = _resolve_maybe_relative(v.strip())
152
+ out[k] = rv if rv else v
153
+ elif isinstance(v, list):
154
+ vals = []
155
+ for item in v:
156
+ if isinstance(item, str) and item.strip():
157
+ rv = _resolve_maybe_relative(item.strip())
158
+ vals.append(rv if rv else item)
159
+ else:
160
+ vals.append(item)
161
+ out[k] = vals
162
+ else:
163
+ out[k] = v
164
+ return out
r3pm_net/feature_extractor.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Feature extractor for R3PM-Net
2
+ '''
3
+ Unfortunately, the feature extractor cannot be provided in the repository due to copyright issues.
4
+ Please implement the feature extractor for R3PM-Net as described in the paper and place it in this file.
5
+ Currently, the feature extractor is set to PPFNet (same as RPMNet).
6
+ '''
7
+ from thirdparty.learning3d.models import PPFNet
8
+ feature_extractor = PPFNet()
r3pm_net/model.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from thirdparty.learning3d.utils import square_distance, angle_difference
8
+ from thirdparty.learning3d.ops.transform_functions import convert2transformation
9
+
10
+ _EPS = 1e-5 # To prevent division by zero
11
+
12
+ class ParameterPredictionNet(nn.Module):
13
+ def __init__(self, weights_dim):
14
+ """PointNet based Parameter prediction network
15
+
16
+ Args:
17
+ weights_dim: Number of weights to predict (excluding beta), should be something like
18
+ [3], or [64, 3], for 3 types of features
19
+ """
20
+
21
+ super().__init__()
22
+
23
+ self._logger = logging.getLogger(self.__class__.__name__)
24
+
25
+ self.weights_dim = weights_dim
26
+
27
+ # Pointnet
28
+ self.prepool = nn.Sequential(
29
+ nn.Conv1d(4, 64, 1),
30
+ nn.GroupNorm(8, 64),
31
+ nn.ReLU(),
32
+
33
+ nn.Conv1d(64, 64, 1),
34
+ nn.GroupNorm(8, 64),
35
+ nn.ReLU(),
36
+
37
+ nn.Conv1d(64, 64, 1),
38
+ nn.GroupNorm(8, 64),
39
+ nn.ReLU(),
40
+
41
+ nn.Conv1d(64, 128, 1),
42
+ nn.GroupNorm(8, 128),
43
+ nn.ReLU(),
44
+
45
+ nn.Conv1d(128, 1024, 1),
46
+ nn.GroupNorm(16, 1024),
47
+ nn.ReLU(),
48
+ )
49
+ self.pooling = nn.AdaptiveMaxPool1d(1)
50
+ self.postpool = nn.Sequential(
51
+ nn.Linear(1024, 512),
52
+ nn.GroupNorm(16, 512),
53
+ nn.ReLU(),
54
+
55
+ nn.Linear(512, 256),
56
+ nn.GroupNorm(16, 256),
57
+ nn.ReLU(),
58
+
59
+ nn.Linear(256, 2 + np.prod(weights_dim)),
60
+ )
61
+
62
+ self._logger.info('Predicting weights with dim {}.'.format(self.weights_dim))
63
+
64
+ def forward(self, x):
65
+ """ Returns alpha, beta, and gating_weights (if needed)
66
+
67
+ Args:
68
+ x: List containing two point clouds, x[0] = src (B, J, 3), x[1] = ref (B, K, 3)
69
+
70
+ Returns:
71
+ beta, alpha, weightings
72
+ """
73
+ # X and Y concatenated
74
+ src_padded = F.pad(x[0], (0, 1), mode='constant', value=0)
75
+ ref_padded = F.pad(x[1], (0, 1), mode='constant', value=1)
76
+ concatenated = torch.cat([src_padded, ref_padded], dim=1)
77
+
78
+ prepool_feat = self.prepool(concatenated.permute(0, 2, 1))
79
+ pooled = torch.flatten(self.pooling(prepool_feat), start_dim=-2)
80
+ raw_weights = self.postpool(pooled)
81
+
82
+ # softplus to ensure positivity
83
+ beta = F.softplus(raw_weights[:, 0])
84
+ alpha = F.softplus(raw_weights[:, 1])
85
+
86
+ return beta, alpha
87
+
88
+
89
+
90
+ def to_numpy(tensor):
91
+ """Wrapper around .detach().cpu().numpy() """
92
+ if isinstance(tensor, torch.Tensor):
93
+ return tensor.detach().cpu().numpy()
94
+ elif isinstance(tensor, np.ndarray):
95
+ return tensor
96
+ else:
97
+ raise NotImplementedError
98
+
99
+
100
+ def se3_transform(g, a, normals=None):
101
+ """ Applies the SE3 transform
102
+
103
+ Args:
104
+ g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4)
105
+ a: Points to be transformed (N, 3) or (B, N, 3)
106
+ normals: (Optional). If provided, normals will be transformed
107
+
108
+ Returns:
109
+ transformed points of size (N, 3) or (B, N, 3)
110
+
111
+ """
112
+ R = g[..., :3, :3] # (B, 3, 3)
113
+ p = g[..., :3, 3] # (B, 3)
114
+
115
+ if len(g.size()) == len(a.size()):
116
+ b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]
117
+ else:
118
+ raise NotImplementedError
119
+ b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p # No batch. Not checked
120
+
121
+ if normals is not None:
122
+ rotated_normals = normals @ R.transpose(-1, -2)
123
+ return b, rotated_normals
124
+
125
+ else:
126
+ return b
127
+
128
+ def match_features(feat_src, feat_ref, metric='l2'):
129
+ """ Compute pairwise distance between features
130
+
131
+ Args:
132
+ feat_src: (B, J, C)
133
+ feat_ref: (B, K, C)
134
+ metric: either 'angle' or 'l2' (squared euclidean)
135
+
136
+ Returns:
137
+ Matching matrix (B, J, K). i'th row describes how well the i'th point
138
+ in the src agrees with every point in the ref.
139
+ """
140
+ if feat_src.shape[-1] != feat_ref.shape[-1]:
141
+ if feat_src.shape[-1] > feat_ref.shape[-1]:
142
+ feat_src = feat_src[:,:,:feat_ref.shape[-1]]
143
+ elif feat_src.shape[-1] < feat_ref.shape[-1]:
144
+ feat_ref = feat_ref[:,:,:feat_src.shape[-1]]
145
+
146
+ assert feat_src.shape[-1] == feat_ref.shape[-1]
147
+
148
+ if metric == 'l2':
149
+ dist_matrix = square_distance(feat_src, feat_ref)
150
+ elif metric == 'angle':
151
+ feat_src_norm = feat_src / (torch.norm(feat_src, dim=-1, keepdim=True) + _EPS)
152
+ feat_ref_norm = feat_ref / (torch.norm(feat_ref, dim=-1, keepdim=True) + _EPS)
153
+
154
+ dist_matrix = angle_difference(feat_src_norm, feat_ref_norm)
155
+ else:
156
+ raise NotImplementedError
157
+
158
+ return dist_matrix
159
+
160
+
161
+ def sinkhorn(log_alpha, n_iters: int = 5, slack: bool = True, eps: float = -1) -> torch.Tensor:
162
+ """ Run sinkhorn iterations to generate a near doubly stochastic matrix, where each row or column sum to <=1
163
+
164
+ Args:
165
+ log_alpha: log of positive matrix to apply sinkhorn normalization (B, J, K)
166
+ n_iters (int): Number of normalization iterations
167
+ slack (bool): Whether to include slack row and column
168
+ eps: eps for early termination (Used only for handcrafted RPM). Set to negative to disable.
169
+
170
+ Returns:
171
+ log(perm_matrix): Doubly stochastic matrix (B, J, K)
172
+
173
+ Modified from original source taken from:
174
+ Learning Latent Permutations with Gumbel-Sinkhorn Networks
175
+ https://github.com/HeddaCohenIndelman/Learning-Gumbel-Sinkhorn-Permutations-w-Pytorch
176
+ """
177
+
178
+ # Sinkhorn iterations
179
+ prev_alpha = None
180
+ if slack:
181
+ zero_pad = nn.ZeroPad2d((0, 1, 0, 1))
182
+ log_alpha_padded = zero_pad(log_alpha[:, None, :, :])
183
+
184
+ log_alpha_padded = torch.squeeze(log_alpha_padded, dim=1)
185
+
186
+ for i in range(n_iters):
187
+ # Row normalization
188
+ log_alpha_padded = torch.cat((
189
+ log_alpha_padded[:, :-1, :] - (torch.logsumexp(log_alpha_padded[:, :-1, :], dim=2, keepdim=True)),
190
+ log_alpha_padded[:, -1, None, :]), # Don't normalize last row
191
+ dim=1)
192
+
193
+ # Column normalization
194
+ log_alpha_padded = torch.cat((
195
+ log_alpha_padded[:, :, :-1] - (torch.logsumexp(log_alpha_padded[:, :, :-1], dim=1, keepdim=True)),
196
+ log_alpha_padded[:, :, -1, None]), # Don't normalize last column
197
+ dim=2)
198
+
199
+ if eps > 0:
200
+ if prev_alpha is not None:
201
+ abs_dev = torch.abs(torch.exp(log_alpha_padded[:, :-1, :-1]) - prev_alpha)
202
+ if torch.max(torch.sum(abs_dev, dim=[1, 2])) < eps:
203
+ break
204
+ prev_alpha = torch.exp(log_alpha_padded[:, :-1, :-1]).clone()
205
+
206
+ log_alpha = log_alpha_padded[:, :-1, :-1]
207
+ else:
208
+ for i in range(n_iters):
209
+ # Row normalization (i.e. each row sum to 1)
210
+ log_alpha = log_alpha - (torch.logsumexp(log_alpha, dim=2, keepdim=True))
211
+
212
+ # Column normalization (i.e. each column sum to 1)
213
+ log_alpha = log_alpha - (torch.logsumexp(log_alpha, dim=1, keepdim=True))
214
+
215
+ if eps > 0:
216
+ if prev_alpha is not None:
217
+ abs_dev = torch.abs(torch.exp(log_alpha) - prev_alpha)
218
+ if torch.max(torch.sum(abs_dev, dim=[1, 2])) < eps:
219
+ break
220
+ prev_alpha = torch.exp(log_alpha).clone()
221
+
222
+ return log_alpha
223
+
224
+
225
+ def compute_rigid_transform(a: torch.Tensor, b: torch.Tensor, weights: torch.Tensor):
226
+ """Compute rigid transforms between two point sets
227
+
228
+ Args:
229
+ a (torch.Tensor): (B, M, 3) points
230
+ b (torch.Tensor): (B, N, 3) points
231
+ weights (torch.Tensor): (B, M)
232
+
233
+ Returns:
234
+ Transform T (B, 3, 4) to get from a to b, i.e. T*a = b
235
+ """
236
+
237
+ weights_normalized = weights[..., None] / (torch.sum(weights[..., None], dim=1, keepdim=True) + _EPS)
238
+ centroid_a = torch.sum(a * weights_normalized, dim=1)
239
+ centroid_b = torch.sum(b * weights_normalized, dim=1)
240
+ a_centered = a - centroid_a[:, None, :]
241
+ b_centered = b - centroid_b[:, None, :]
242
+ cov = a_centered.transpose(-2, -1) @ (b_centered * weights_normalized)
243
+
244
+ # Compute rotation using Kabsch algorithm. Will compute two copies with +/-V[:,:3]
245
+ # and choose based on determinant to avoid flips
246
+ u, s, v = torch.svd(cov, some=False, compute_uv=True)
247
+ rot_mat_pos = v @ u.transpose(-1, -2)
248
+ v_neg = v.clone()
249
+ v_neg[:, :, 2] *= -1
250
+ rot_mat_neg = v_neg @ u.transpose(-1, -2)
251
+ rot_mat = torch.where(torch.det(rot_mat_pos)[:, None, None] > 0, rot_mat_pos, rot_mat_neg)
252
+ assert torch.all(torch.det(rot_mat) > 0)
253
+
254
+ # Compute translation (uncenter centroid)
255
+ translation = -rot_mat @ centroid_a[:, :, None] + centroid_b[:, :, None]
256
+
257
+ transform = torch.cat((rot_mat, translation), dim=2)
258
+ return transform
259
+
260
+ class R3PMNet(nn.Module):
261
+ def __init__(self, feature_model):
262
+
263
+ super().__init__()
264
+
265
+ self.add_slack = True
266
+ self.num_sk_iter = 5
267
+
268
+ self.weights_net = ParameterPredictionNet(weights_dim=[0])
269
+ self.feat_extractor = feature_model
270
+
271
+ def compute_affinity(self, beta, feat_distance, alpha=0.5):
272
+ """Compute logarithm of Initial match matrix values, i.e. log(m_jk)"""
273
+ if isinstance(alpha, float):
274
+ hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha)
275
+ else:
276
+ hybrid_affinity = -beta[:, None, None] * (feat_distance - alpha[:, None, None])
277
+ return hybrid_affinity
278
+
279
+ @staticmethod
280
+ def split_normals(data):
281
+ if data.shape[2] == 6:
282
+ xyz, normals = data[:, :, :3], data[:, :, 3:6]
283
+ elif data.shape[2] == 3:
284
+ xyz, normals = data, torch.zeros(data.shape).to(data.device)
285
+ return xyz, normals
286
+
287
+ def spam(self, xyz_template, norm_template, xyz_source, norm_source):
288
+ self.beta, self.alpha = self.weights_net([xyz_source, xyz_template])
289
+
290
+ try: # R3PMNET feature extractor
291
+ self.feat_source = self.feat_extractor(xyz_source)
292
+ self.feat_template = self.feat_extractor(xyz_template)
293
+ except:
294
+ self.feat_source = self.feat_extractor(xyz_source, norm_source)
295
+ self.feat_template = self.feat_extractor(xyz_template, norm_template)
296
+
297
+ feat_distance = match_features(self.feat_source, self.feat_template)
298
+ self.affinity = self.compute_affinity(self.beta, feat_distance, alpha=self.alpha)
299
+
300
+ # Compute weighted coordinates
301
+ log_perm_matrix = sinkhorn(self.affinity, n_iters=self.num_sk_iter, slack=self.add_slack)
302
+ self.perm_matrix = torch.exp(log_perm_matrix)
303
+
304
+ try: # R3PMNET features
305
+ weighted_template = self.perm_matrix @ xyz_template[:,:self.perm_matrix.shape[1]] / (torch.sum(self.perm_matrix, dim=2, keepdim=True) + _EPS)
306
+ except:
307
+ weighted_template = self.perm_matrix @ xyz_template / (torch.sum(self.perm_matrix, dim=2, keepdim=True) + _EPS)
308
+ return weighted_template
309
+
310
+ def forward(self, template, source, max_iterations: int = 1):
311
+ """Forward pass for R3PM-Net
312
+
313
+ Args:
314
+ data: Dict containing the following fields:
315
+ 'points_src': Source points (B, J, 6)
316
+ 'points_ref': Reference points (B, K, 6)
317
+ num_iter (int): Number of iterations. Recommended to be 2 for training
318
+
319
+ Returns:
320
+ transform: Transform to apply to source points such that they align to reference
321
+ src_transformed: Transformed source points
322
+ """
323
+
324
+ xyz_template, norm_template = self.split_normals(template)
325
+ xyz_source, norm_source = self.split_normals(source)
326
+
327
+ xyz_source_t, norm_source_t = xyz_source, norm_source # a copy of source to apply transformation to
328
+
329
+ transforms = []
330
+ all_gamma, all_perm_matrices, all_weighted_template = [], [], []
331
+ all_beta, all_alpha = [], []
332
+
333
+ for i in range(max_iterations):
334
+ weighted_template = self.spam(xyz_template, norm_template, xyz_source_t, norm_source_t) # Finding better correspondences after each iteration.
335
+
336
+ # Compute transform and transform points
337
+ try: # R3PMNET features
338
+ transform = compute_rigid_transform(xyz_source[:,:weighted_template.shape[1]], weighted_template, weights=torch.sum(self.perm_matrix, dim=2))
339
+ xyz_source_t, norm_source_t = se3_transform(transform.detach(), xyz_source[:,:weighted_template.shape[1]], norm_source) # Apply transformation to original source.
340
+ except:
341
+ transform = compute_rigid_transform(xyz_source_t, weighted_template, weights=torch.sum(self.perm_matrix, dim=2))
342
+ xyz_source_t, norm_source_t = se3_transform(transform.detach(), xyz_source, norm_source) # Apply transformation to original source.
343
+
344
+
345
+ transforms.append(transform)
346
+ all_gamma.append(torch.exp(self.affinity))
347
+ all_perm_matrices.append(self.perm_matrix)
348
+ all_weighted_template.append(weighted_template)
349
+ all_beta.append(to_numpy(self.beta))
350
+ all_alpha.append(to_numpy(self.alpha))
351
+
352
+ est_T = convert2transformation(transforms[max_iterations-1][:, :3, :3], transforms[max_iterations-1][:, :3, 3])
353
+ transformed_source = torch.bmm(est_T[:, :3, :3], source[:,:,:3].permute(0, 2, 1)).permute(0, 2, 1) + est_T[:, :3, 3].unsqueeze(1)
354
+
355
+ try: # for training
356
+ result = {'est_R': est_T[:, :3, :3], # source -> template
357
+ 'est_t': est_T[:, :3, 3], # source -> template
358
+ 'est_T': est_T, # source -> template
359
+ 'r': self.feat_template - self.feat_source,
360
+ 'transformed_source': transformed_source}
361
+ except RuntimeError:
362
+ result = {'est_R': est_T[:, :3, :3], # source -> template
363
+ 'est_t': est_T[:, :3, 3], # source -> template
364
+ 'est_T': est_T, # source -> template
365
+ 'transformed_source': transformed_source}
366
+
367
+ result['perm_matrices_init'] = all_gamma
368
+ result['perm_matrices'] = all_perm_matrices
369
+ result['weighted_template'] = all_weighted_template
370
+ result['beta'] = np.stack(all_beta, axis=0)
371
+ result['alpha'] = np.stack(all_alpha, axis=0)
372
+ result['transforms'] = transforms
373
+
374
+ return result
375
+
376
+
377
+ if __name__ == '__main__':
378
+ template, source = torch.rand(10,1024,6), torch.rand(10,1024,6)
379
+
380
+ net = R3PMNet()
381
+ result = net(template, source)
382
+ import ipdb; ipdb.set_trace()
r3pm_net/paths.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Repository-root resolution for portable paths."""
2
+
3
+ from pathlib import Path
4
+
5
+ # r3pm_net/paths.py -> parents[1] is the repository root
6
+ REPO_ROOT = Path(__file__).resolve().parents[1]
7
+
8
+
9
+ def repo_path(*parts: str) -> str:
10
+ """Join path segments relative to the repository root."""
11
+ return str(REPO_ROOT.joinpath(*parts))
scripts/eval_modelnet40.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ # Repository root on PYTHONPATH (run: python scripts/test_modelnet40.py from repo root).
8
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
9
+ if str(_REPO_ROOT) not in sys.path:
10
+ sys.path.insert(0, str(_REPO_ROOT))
11
+
12
+ import argparse
13
+ import random
14
+
15
+ import numpy as np
16
+ import open3d as o3d
17
+ import torch
18
+ from tqdm import tqdm
19
+
20
+ from tools import augmentation, data, l3d_helper, print_results, transformations
21
+ from tools import l3d_registration_and_evaluation, predator_registration_and_evaluation, geotransformer_registration_and_evaluation, logdesc_registration_and_evaluation, regtr_registration_and_evaluation
22
+ from r3pm_net.config_loader import get_method_paths,get_modelnet40_paths, get_pretrained_rpmnet_dir
23
+
24
+ '''
25
+ This script evaluates the performance on the ModelNet40 test dataset.
26
+ The results are averaged ovet the dataset with 2468 samples.
27
+ All the point clouds are normalized to a sphere of radius 1.
28
+
29
+ Augmentations:
30
+ - Transformation = Random rotation (0 - 45) and translation (-0.5 to 0.5)
31
+ - Noise = Gaussian noise with mean 0 and std deviation of 0.01 [optional]
32
+ - Outliers = with level 1 which means 2% of the points are outliers (PC size = 2040) [optional]
33
+ - Occlusion = 90000 radius which means 0.7% of the points are occluded (PC size = 1986) [optional]
34
+ '''
35
+ def set_seed(seed: int) -> None:
36
+ os.environ["PYTHONHASHSEED"] = str(seed)
37
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
38
+
39
+ random.seed(seed)
40
+ np.random.seed(seed)
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed_all(seed)
43
+
44
+ torch.backends.cudnn.benchmark = False
45
+ torch.backends.cudnn.deterministic = True
46
+ torch.use_deterministic_algorithms(True)
47
+
48
+ # arguments
49
+ parser = argparse.ArgumentParser(description="ModelNet40 R3PM-Net evaluation")
50
+ parser.add_argument("--seed", type=int, default=42, help="random seed (default: 42)")
51
+
52
+ args = parser.parse_args()
53
+ set_seed(args.seed)
54
+ method_paths = get_method_paths()
55
+
56
+ pretrained_base_dir = get_pretrained_rpmnet_dir()
57
+ _path_zs = os.path.join(pretrained_base_dir, "clean-trained.pth")
58
+ _path_ft = os.path.join(pretrained_base_dir, "best_model_PointNet.t7") #TODO: CHANGE
59
+
60
+ def fix_off_file(file_path):
61
+ with open(file_path, 'r') as f:
62
+ lines = f.readlines()
63
+
64
+ if lines[0].startswith("OFF") and len(lines[0].strip().split()) > 1:
65
+ header = lines[0].strip()
66
+ new_header = "OFF\n" + header[3:] + "\n"
67
+ lines = [new_header] + lines[1:]
68
+
69
+ with open(file_path, 'w') as f:
70
+ f.writelines(lines)
71
+ print(f"Fixed: {file_path}")
72
+
73
+ def load_modelnet40_test_data(dataset_path, num_points=2000):
74
+ test_data = []
75
+ test_labels = []
76
+ categories = os.listdir(dataset_path)
77
+ for label, category in enumerate(tqdm(categories, desc="Loading Data")):
78
+ test_dir = os.path.join(dataset_path, category, 'test')
79
+ if not os.path.exists(test_dir):
80
+ continue
81
+ for file in tqdm(os.listdir(test_dir), desc=f"Processing {category} Category", leave=False):
82
+ if file.endswith('.off'):
83
+ file_path = os.path.join(test_dir, file)
84
+ mesh = o3d.io.read_triangle_mesh(file_path)
85
+ point_cloud = mesh.sample_points_poisson_disk(number_of_points=num_points)
86
+ test_data.append(point_cloud)
87
+ test_labels.append(label)
88
+
89
+ return test_data, test_labels, categories
90
+
91
+ # download from http://modelnet.cs.princeton.edu/ModelNet40.zip unzip and put the path in the config/eval.yaml
92
+ dataset_path, save_dir = get_modelnet40_paths()
93
+ test_data_path = os.path.join(save_dir, "test_data.npy")
94
+ test_labels_path = os.path.join(save_dir, "test_labels.npy")
95
+ categories_path = os.path.join(save_dir, "categories.npy")
96
+
97
+ os.makedirs(save_dir, exist_ok=True)
98
+
99
+ # Check if data already exists
100
+ if os.path.exists(test_data_path) and os.path.exists(test_labels_path) and os.path.exists(categories_path):
101
+ print("Loading existing test data...")
102
+ test_data_np = np.load(test_data_path, allow_pickle=True)
103
+ test_labels = np.load(test_labels_path)
104
+ categories = np.load(categories_path)
105
+ print("Done! Testing the models...")
106
+ else:
107
+ print("Loading and processing ModelNet40 test data...")
108
+ # Fix all .OFF files in the dataset
109
+ for root, _, files in os.walk(dataset_path):
110
+ for file in files:
111
+ if file.endswith(".off"):
112
+ fix_off_file(os.path.join(root, file))
113
+
114
+ test_data, test_labels, categories = load_modelnet40_test_data(dataset_path)
115
+
116
+ test_data_np = [data.normalize_pc(pc, return_as_np = True) for pc in test_data]
117
+
118
+ np.save(test_data_path, test_data_np)
119
+ np.save(test_labels_path, test_labels)
120
+ np.save(categories_path, categories)
121
+ print("Test data saved!")
122
+
123
+ # Initialize arrays to store results
124
+ rpm_results_all = []
125
+ predator_results_all = []
126
+ geotransformer_results_all = []
127
+ logdesc_results_all = []
128
+ regtr_results_all = []
129
+ r3pm_net_results_all = []
130
+ tuned_r3pm_net_results_all = []
131
+
132
+ rpm_reg_results_all = []
133
+ predator_reg_results_all = []
134
+ geotransformer_reg_results_all = []
135
+ logdesc_reg_results_all = []
136
+ regtr_reg_results_all = []
137
+ r3pm_net_reg_results_all = []
138
+ tuned_r3pm_net_reg_results_all = []
139
+
140
+ all_sources = []
141
+ all_targets = []
142
+ all_angles ={}
143
+
144
+ # Reconstruct Open3D PointCloud objects from saved npy arrays
145
+ test_data = [o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points)) for points in test_data_np]
146
+
147
+ noise_level = 0
148
+ outlier_level = 0
149
+ outlier_lowerbound = -0.5
150
+ outlier_upperbound = 0.5
151
+ # occlusion_level = 90000 # Higher value means less occlusion
152
+ occlusion_level = 0 # Higher value means less occlusion
153
+
154
+
155
+ # set arguments for models
156
+ rpm_args = l3d_helper.options(modelName="RPMNet")
157
+ rpm_args.pretrained = _path_zs
158
+
159
+ # OverlapPredator (used by Predator runner)
160
+ predator_cfg = method_paths.get("predator", {})
161
+ predator_root = predator_cfg.get("root")
162
+ predator_config_path = predator_cfg.get("config_path")
163
+ predator_weights_path = predator_cfg.get("weights_path")
164
+
165
+ # GeoTransformer
166
+ geo_cfg = method_paths.get("geotransformer", {})
167
+ geotransformer_root = geo_cfg.get("root")
168
+ geotransformer_exp_subdir = geo_cfg.get("exp_subdir")
169
+ geotransformer_weights_path = geo_cfg.get("weights_path")
170
+
171
+ # LoGDesc
172
+ logdesc_cfg = method_paths.get("logdesc", {})
173
+ logdesc_root = logdesc_cfg.get("root")
174
+ logdesc_weights_path = logdesc_cfg.get("weights_path")
175
+
176
+ # RegTR
177
+ regtr_cfg = method_paths.get("regtr", {})
178
+ regtr_root = regtr_cfg.get("root")
179
+ regtr_ckpt_path = regtr_cfg.get("ckpt_path")
180
+ regtr_config_path = regtr_cfg.get("config_path")
181
+
182
+ # R3PM-Net (ours) - ZS - no training
183
+ r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
184
+ r3pm_net_args.pretrained = _path_zs
185
+
186
+ # R3PM-Net (ours) - FT
187
+ tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
188
+ tuned_r3pm_net_args.pretrained = _path_ft
189
+
190
+ for i, item in enumerate(tqdm(test_data, desc="Testing methods")):
191
+
192
+ # Simulate data
193
+ x_angle = int(random.uniform(0, 45))
194
+ y_angle = int(random.uniform(0, 45))
195
+ z_angle = int(random.uniform(0, 45))
196
+ translation_range = (-0.5, 0.5)
197
+ gt_transformation = transformations.create_transformation(x_angle, y_angle, z_angle, translation_range)
198
+ source = copy.deepcopy(item)
199
+
200
+ target = copy.deepcopy(item).transform(gt_transformation)
201
+
202
+ # Apply augmentations
203
+ noisy_source = copy.deepcopy(source)
204
+ if noise_level != 0:
205
+ noisy_source = augmentation.apply_noise(noisy_source, noise_level)
206
+ if outlier_level != 0:
207
+ noisy_source = augmentation.add_outliers(noisy_source, outlier_level, outlier_lowerbound, outlier_upperbound)
208
+ if occlusion_level != 0:
209
+ noisy_source, _ = augmentation.apply_occlusion(noisy_source, occlusion_level)
210
+ if len(noisy_source.points) < 1024: # cannot be smaller than embedding dims in config/default.yaml
211
+ noisy_source = copy.deepcopy(source)
212
+ noisy_source = augmentation.apply_noise(noisy_source, noise_level)
213
+ noisy_source, _ = augmentation.apply_occlusion(noisy_source, occlusion_level * 100)
214
+ assert len(noisy_source.points) >= 1024, "Noisy source point cloud has less than 1024 points."
215
+
216
+ # RPMNet
217
+ rpm_results_pc, rpm_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
218
+ noisy_source, target, 'rpmnet', gt_transformation, rpm_args)
219
+ rpm_results_all.append(rpm_results)
220
+ rpm_reg_results_all.append(rpm_results_pc)
221
+
222
+ # OverlapPredator
223
+ predator_results_pc, predator_results = predator_registration_and_evaluation.predator_reg_and_eval(
224
+ noisy_source,
225
+ target,
226
+ gt_transformation=gt_transformation,
227
+ predator_root=predator_root,
228
+ config_path=predator_config_path,
229
+ weights_path=predator_weights_path,
230
+ ransac_n_points=1000,
231
+ ransac_distance_threshold=0.05,
232
+ ransac_n=3,
233
+ sampling="prob",
234
+ mutual=False,
235
+ input_num_points=1024,
236
+ )
237
+ predator_results_all.append(predator_results)
238
+ predator_reg_results_all.append(predator_results_pc)
239
+
240
+ # GeoTransformer (ModelNet)
241
+ geotransformer_results_pc, geotransformer_results = geotransformer_registration_and_evaluation.geotransformer_reg_and_eval(
242
+ noisy_source,
243
+ target,
244
+ gt_transformation=gt_transformation,
245
+ geotransformer_root=geotransformer_root,
246
+ exp_subdir=geotransformer_exp_subdir,
247
+ weights_path=geotransformer_weights_path,
248
+ )
249
+ geotransformer_results_all.append(geotransformer_results)
250
+ geotransformer_reg_results_all.append(geotransformer_results_pc)
251
+
252
+ # LoGDesc
253
+ logdesc_results_pc, logdesc_results = logdesc_registration_and_evaluation.logdesc_reg_and_eval(
254
+ noisy_source,
255
+ target,
256
+ gt_transformation=gt_transformation,
257
+ logdesc_root=logdesc_root,
258
+ weights_path=logdesc_weights_path,
259
+ max_keypoints=768,
260
+ num_points_per_sample=128,
261
+ sample_radius=0.3,
262
+ topk_matches=128,
263
+ use_kpt=False,
264
+ )
265
+ logdesc_results_all.append(logdesc_results)
266
+ logdesc_reg_results_all.append(logdesc_results_pc)
267
+
268
+ # RegTR (ModelNet)
269
+ regtr_results_pc, regtr_results = regtr_registration_and_evaluation.regtr_reg_and_eval(
270
+ noisy_source,
271
+ target,
272
+ gt_transformation=gt_transformation,
273
+ regtr_root=regtr_root,
274
+ ckpt_path=regtr_ckpt_path,
275
+ config_path=regtr_config_path,
276
+ )
277
+ regtr_results_all.append(regtr_results)
278
+ regtr_reg_results_all.append(regtr_results_pc)
279
+
280
+ # R3PM-Net (ours) - no training
281
+ r3pm_net_results_pc, r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
282
+ noisy_source, target, 'r3pmnet', gt_transformation, r3pm_net_args)
283
+ r3pm_net_results_all.append(r3pm_net_results)
284
+ r3pm_net_reg_results_all.append(r3pm_net_results_pc)
285
+
286
+ # R3PM-Net (ours) (Tuned on 4 sioux data)
287
+ tuned_r3pm_net_results_pc, tuned_r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
288
+ noisy_source, target, 'r3pmnet', gt_transformation, tuned_r3pm_net_args)
289
+ tuned_r3pm_net_results_all.append(tuned_r3pm_net_results)
290
+ tuned_r3pm_net_reg_results_all.append(tuned_r3pm_net_results_pc)
291
+
292
+
293
+ all_sources.append(noisy_source)
294
+ all_targets.append(target)
295
+ all_angles[i] = {
296
+ "x_angle": x_angle,
297
+ "y_angle": y_angle,
298
+ "z_angle": z_angle,
299
+ "translation": gt_transformation[:3, 3]
300
+ }
301
+
302
+ # Convert results to numpy arrays for easier manipulation
303
+ rpm_results_all = np.array(rpm_results_all)
304
+ predator_results_all = np.array(predator_results_all)
305
+ geotransformer_results_all = np.array(geotransformer_results_all)
306
+ logdesc_results_all = np.array(logdesc_results_all)
307
+ regtr_results_all = np.array(regtr_results_all)
308
+ r3pm_net_results_all = np.array(r3pm_net_results_all)
309
+ tuned_r3pm_net_results_all = np.array(tuned_r3pm_net_results_all)
310
+
311
+ rpm_mean_results = np.mean(rpm_results_all, axis=0)
312
+ predator_mean_results = np.mean(predator_results_all, axis=0)
313
+ geotransformer_mean_results = np.mean(geotransformer_results_all, axis=0)
314
+ logdesc_mean_results = np.mean(logdesc_results_all, axis=0)
315
+ regtr_mean_results = np.mean(regtr_results_all, axis=0)
316
+ r3pm_net_mean_results = np.mean(r3pm_net_results_all, axis=0)
317
+ tuned_r3pm_net_mean_results = np.mean(tuned_r3pm_net_results_all, axis=0)
318
+
319
+ # Print the results
320
+ metric_names = ['mean_rmse', 'mean_rotation_error', 'mean_translation_error',
321
+ 'mean_computation_time', 'mean_cd', 'mean_error',
322
+ 'mean_fitness', 'mean_inlier_rmse']
323
+
324
+ reports = {
325
+ "RPMNet": dict(zip(metric_names, rpm_mean_results)),
326
+ "Predator": dict(zip(metric_names, predator_mean_results)),
327
+ "GeoTransformer": dict(zip(metric_names, geotransformer_mean_results)),
328
+ "LoGDesc": dict(zip(metric_names, logdesc_mean_results)),
329
+ "RegTR": dict(zip(metric_names, regtr_mean_results)),
330
+ "R3PM-Net (ours) (ZS)": dict(zip(metric_names, r3pm_net_mean_results)),
331
+ "R3PM-Net (ours) (FT)": dict(zip(metric_names, tuned_r3pm_net_mean_results)),
332
+ }
333
+
334
+ # Print the table
335
+ print_results.print_table(reports)
scripts/eval_sioux_cranfield.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import open3d as o3d
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import sys
7
+ from pathlib import Path
8
+ import torch
9
+ import random
10
+ import argparse
11
+
12
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
13
+ if str(_REPO_ROOT) not in sys.path:
14
+ sys.path.insert(0, str(_REPO_ROOT))
15
+
16
+ from tools import augmentation, data, l3d_helper, print_results, transformations
17
+ from tools import l3d_registration_and_evaluation, predator_registration_and_evaluation, geotransformer_registration_and_evaluation, logdesc_registration_and_evaluation, regtr_registration_and_evaluation
18
+ from r3pm_net.config_loader import get_method_paths, get_pretrained_rpmnet_dir, get_sioux_data_root, get_sioux_paths
19
+ '''
20
+ This script evaluates the performance on a Sioux-Cranfield dataset
21
+ Cranfield dataset from: https://github.com/Menthy-Denayer/PCR_CAD_Model_Alignment_Comparison/tree/main/datasets
22
+ '''
23
+ def set_seed(seed: int) -> None:
24
+ os.environ["PYTHONHASHSEED"] = str(seed)
25
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
26
+
27
+ random.seed(seed)
28
+ np.random.seed(seed)
29
+ torch.manual_seed(seed)
30
+ torch.cuda.manual_seed_all(seed)
31
+
32
+ torch.backends.cudnn.benchmark = False
33
+ torch.backends.cudnn.deterministic = True
34
+ torch.use_deterministic_algorithms(True)
35
+
36
+ # arguments
37
+ parser = argparse.ArgumentParser(description="Sioux-Cranfield R3PM-Net evaluation")
38
+ parser.add_argument("--seed", type=int, default=42, help="random seed (default: 42)")
39
+ args = parser.parse_args()
40
+ set_seed(args.seed)
41
+
42
+ base_dir = get_sioux_data_root()
43
+ sioux_cfg = get_sioux_paths()
44
+ method_paths = get_method_paths()
45
+
46
+ pretrained_base_dir = get_pretrained_rpmnet_dir()
47
+ _path_zs = os.path.join(pretrained_base_dir, "clean-trained.pth")
48
+ _path_ft = os.path.join(pretrained_base_dir, "best_model_PointNet.t7") #TODO: CHANGE
49
+
50
+ # Paths to the CAD models
51
+ cad_dir_made = os.path.join(base_dir, 'sioux_cranfield')
52
+
53
+ cad_paths = [os.path.join(cad_dir_made, 'Base-Top_Plate.stl'),
54
+ os.path.join(cad_dir_made, 'Pendulum.stl'),
55
+ os.path.join(cad_dir_made, 'Round-Peg.stl'),
56
+ os.path.join(cad_dir_made, 'Separator.stl'),
57
+ os.path.join(cad_dir_made, 'Shaft-New.stl'),
58
+ os.path.join(cad_dir_made, 'Square-Peg.stl'),
59
+ os.path.join(cad_dir_made, 'elephant.stl'),
60
+ os.path.join(cad_dir_made, 'house.stl'),
61
+ os.path.join(cad_dir_made, 'shoe.stl')]
62
+
63
+ # Test parameters
64
+ num_tests = 25
65
+ angles = list(range(0, 45))
66
+ translation_range = (-0.5, 0.5)
67
+ np.random.seed(42)
68
+
69
+ # Augmentation parameters
70
+ noise_level = 0
71
+ outlier_level = 0
72
+ outlier_lowerbound = -0.5
73
+ outlier_upperbound = 0.5
74
+ # occlusion_level = 9000 # Higher value means less occlusion
75
+ occ_level = 0
76
+
77
+ # Make dataset
78
+ sources = []
79
+ targets = []
80
+ x_angles = []
81
+ y_angles = []
82
+ z_angles = []
83
+ gt_transformations = []
84
+
85
+ for cadPath in tqdm (cad_paths, desc="Preparing Sioux-Cranfield Dataset", total=len(cad_paths)):
86
+
87
+ num_points = 2000
88
+ # Load the data
89
+ mesh = o3d.io.read_triangle_mesh(cadPath)
90
+ cad = mesh.sample_points_poisson_disk(number_of_points=num_points) # modify to a suitable number of points
91
+ normalized_point_cloud = data.normalize_pc(cad)
92
+ source = copy.deepcopy(normalized_point_cloud)
93
+
94
+ for test in range(num_tests):
95
+ # Data simulation
96
+ x_angle= np.random.uniform(angles[0], angles[-1], size=1)
97
+ y_angle= np.random.uniform(angles[0], angles[-1], size=1)
98
+ z_angle= np.random.uniform(angles[0], angles[-1], size=1)
99
+ gt_transformation = transformations.create_transformation(x_angle, y_angle, z_angle, translation_range)
100
+ target = copy.deepcopy(normalized_point_cloud).transform(gt_transformation)
101
+
102
+ # Data augmentation
103
+ if occ_level == 0 and noise_level == 0 and outlier_level == 0:
104
+ noisy_source = copy.deepcopy(source)
105
+
106
+ # Noise + Occlusion
107
+ elif occ_level != 0 and noise_level != 0:
108
+ noisy_source_noise = augmentation.apply_noise(source, noise_level)
109
+ noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level)
110
+ if len(noisy_source.points) < 1024: # Handle excessive occlusion
111
+ source = copy.deepcopy(target).transform(gt_transformation)
112
+ noisy_source_noise = augmentation.apply_noise(source, noise_level)
113
+ noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level * 1.5)
114
+
115
+ # Noise + Outlier
116
+ elif noise_level != 0 and outlier_level != 0:
117
+ noisy_source_noise = augmentation.apply_noise(source, noise_level)
118
+ noisy_source = augmentation.add_outliers(noisy_source_noise, outlier_level, outlier_lowerbound=-0.5, outlier_upperbound=0.5)
119
+
120
+ # Noise + Outlier + Occlusion
121
+ elif occ_level != 0 and noise_level != 0 and outlier_level != 0:
122
+ noisy_source_noise = augmentation.apply_noise(source, noise_level)
123
+ noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level)
124
+ if len(noisy_source.points) < 1024: # Handle excessive occlusion
125
+ source = copy.deepcopy(target).transform(gt_transformation)
126
+ noisy_source_noise = augmentation.apply_noise(source, noise_level)
127
+ noisy_source, _ = augmentation.apply_occlusion(noisy_source_noise, occ_level * 1.5)
128
+ noisy_source = augmentation.add_outliers(noisy_source, outlier_level, outlier_lowerbound=-0.5, outlier_upperbound=0.5)
129
+
130
+ # collect dataset in lists
131
+ sources.append(noisy_source)
132
+ targets.append(target)
133
+ x_angles.append(x_angle)
134
+ y_angles.append(y_angle)
135
+ z_angles.append(z_angle)
136
+ gt_transformations.append(gt_transformation)
137
+
138
+ # Initialize arrays to store results
139
+ rpm_results_all = []
140
+ predator_results_all = []
141
+ geotransformer_results_all = []
142
+ logdesc_results_all = []
143
+ regtr_results_all = []
144
+ r3pm_net_results_all = []
145
+ tuned_r3pm_net_results_all = []
146
+
147
+ rpm_reg_results_all = []
148
+ predator_reg_results_all = []
149
+ geotransformer_reg_results_all = []
150
+ logdesc_reg_results_all = []
151
+ regtr_reg_results_all = []
152
+ r3pm_net_reg_results_all = []
153
+ tuned_r3pm_net_reg_results_all = []
154
+
155
+ # set arguments for models
156
+ rpm_args = l3d_helper.options(modelName="RPMNet")
157
+ rpm_args.pretrained = _path_zs
158
+
159
+ # OverlapPredator (used by Predator runner)
160
+ predator_cfg = method_paths.get("predator", {})
161
+ predator_root = predator_cfg.get("root")
162
+ predator_config_path = predator_cfg.get("config_path")
163
+ predator_weights_path = predator_cfg.get("weights_path")
164
+
165
+ # GeoTransformer
166
+ geo_cfg = method_paths.get("geotransformer", {})
167
+ geotransformer_root = geo_cfg.get("root")
168
+ geotransformer_exp_subdir = geo_cfg.get("exp_subdir")
169
+ geotransformer_weights_path = geo_cfg.get("weights_path")
170
+
171
+ # LoGDesc
172
+ logdesc_cfg = method_paths.get("logdesc", {})
173
+ logdesc_root = logdesc_cfg.get("root")
174
+ logdesc_weights_path = logdesc_cfg.get("weights_path")
175
+
176
+ # RegTR
177
+ regtr_cfg = method_paths.get("regtr", {})
178
+ regtr_root = regtr_cfg.get("root")
179
+ regtr_ckpt_path = regtr_cfg.get("ckpt_path")
180
+ regtr_config_path = regtr_cfg.get("config_path")
181
+
182
+ # R3PM-Net (ours) - ZS - no training
183
+ r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
184
+ r3pm_net_args.pretrained = _path_zs
185
+
186
+ # R3PM-Net (ours) - FT
187
+ tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
188
+ tuned_r3pm_net_args.pretrained = _path_ft
189
+
190
+
191
+ for i, item in enumerate(tqdm(zip(sources, targets, gt_transformations), desc="Testing methods", total=len(sources))):
192
+
193
+ # RPMNet
194
+ rpm_results_pc, rpm_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
195
+ sources[i], targets[i], 'rpmnet', gt_transformations[i], rpm_args)
196
+ rpm_results_all.append(rpm_results)
197
+ rpm_reg_results_all.append(rpm_results_pc)
198
+
199
+ # OverlapPredator
200
+ predator_results_pc, predator_results = predator_registration_and_evaluation.predator_reg_and_eval(
201
+ sources[i],
202
+ targets[i],
203
+ gt_transformation=gt_transformations[i],
204
+ predator_root=predator_root,
205
+ config_path=predator_config_path,
206
+ weights_path=predator_weights_path,
207
+ ransac_n_points=1000,
208
+ ransac_distance_threshold=0.05,
209
+ ransac_n=3,
210
+ sampling="prob",
211
+ mutual=False,
212
+ input_num_points=1024,
213
+ )
214
+ predator_results_all.append(predator_results)
215
+ predator_reg_results_all.append(predator_results_pc)
216
+
217
+ # GeoTransformer (ModelNet)
218
+ geotransformer_results_pc, geotransformer_results = geotransformer_registration_and_evaluation.geotransformer_reg_and_eval(
219
+ sources[i],
220
+ targets[i],
221
+ gt_transformation=gt_transformations[i],
222
+ geotransformer_root=geotransformer_root,
223
+ exp_subdir=geotransformer_exp_subdir,
224
+ weights_path=geotransformer_weights_path,
225
+ )
226
+ geotransformer_results_all.append(geotransformer_results)
227
+ geotransformer_reg_results_all.append(geotransformer_results_pc)
228
+
229
+ # LoGDesc
230
+ logdesc_results_pc, logdesc_results = logdesc_registration_and_evaluation.logdesc_reg_and_eval(
231
+ sources[i],
232
+ targets[i],
233
+ gt_transformation=gt_transformations[i],
234
+ logdesc_root=logdesc_root,
235
+ weights_path=logdesc_weights_path,
236
+ max_keypoints=768,
237
+ num_points_per_sample=128,
238
+ sample_radius=0.3,
239
+ topk_matches=128,
240
+ use_kpt=False,
241
+ )
242
+ logdesc_results_all.append(logdesc_results)
243
+ logdesc_reg_results_all.append(logdesc_results_pc)
244
+
245
+ # RegTR (ModelNet)
246
+ regtr_results_pc, regtr_results = regtr_registration_and_evaluation.regtr_reg_and_eval(
247
+ sources[i],
248
+ targets[i],
249
+ gt_transformation=gt_transformations[i],
250
+ regtr_root=regtr_root,
251
+ ckpt_path=regtr_ckpt_path,
252
+ config_path=regtr_config_path,
253
+ )
254
+ regtr_results_all.append(regtr_results)
255
+ regtr_reg_results_all.append(regtr_results_pc)
256
+
257
+ # R3PM-Net (ours) - ZS - no training
258
+ r3pm_net_results_pc, r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
259
+ sources[i], targets[i], 'r3pmnet', gt_transformations[i], r3pm_net_args)
260
+ r3pm_net_results_all.append(r3pm_net_results)
261
+ r3pm_net_reg_results_all.append(r3pm_net_results_pc)
262
+
263
+ # R3PM-Net (ours) - FT
264
+ tuned_r3pm_net_results_pc, tuned_r3pm_net_results = l3d_registration_and_evaluation.l3d_reg_and_eval(
265
+ sources[i], targets[i], 'r3pmnet', gt_transformations[i], tuned_r3pm_net_args)
266
+ tuned_r3pm_net_results_all.append(tuned_r3pm_net_results)
267
+ tuned_r3pm_net_reg_results_all.append(tuned_r3pm_net_results_pc)
268
+
269
+
270
+ # Convert results to numpy arrays for easier manipulation
271
+ rpm_results_all = np.array(rpm_results_all)
272
+ predator_results_all = np.array(predator_results_all)
273
+ geotransformer_results_all = np.array(geotransformer_results_all)
274
+ logdesc_results_all = np.array(logdesc_results_all)
275
+ regtr_results_all = np.array(regtr_results_all)
276
+ r3pm_net_results_all = np.array(r3pm_net_results_all)
277
+ tuned_r3pm_net_results_all = np.array(tuned_r3pm_net_results_all)
278
+
279
+ rpm_mean_results = np.mean(rpm_results_all, axis=0)
280
+ predator_mean_results = np.mean(predator_results_all, axis=0)
281
+ geotransformer_mean_results = np.mean(geotransformer_results_all, axis=0)
282
+ logdesc_mean_results = np.mean(logdesc_results_all, axis=0)
283
+ regtr_mean_results = np.mean(regtr_results_all, axis=0)
284
+ r3pm_net_mean_results = np.mean(r3pm_net_results_all, axis=0)
285
+ tuned_r3pm_net_mean_results = np.mean(tuned_r3pm_net_results_all, axis=0)
286
+
287
+ # Print the results
288
+ metric_names = ['mean_rmse', 'mean_rotation_error', 'mean_translation_error',
289
+ 'mean_computation_time', 'mean_cd', 'mean_error',
290
+ 'mean_fitness', 'mean_inlier_rmse']
291
+
292
+ reports = {
293
+ "RPMNet": dict(zip(metric_names, rpm_mean_results)),
294
+ "Predator": dict(zip(metric_names, predator_mean_results)),
295
+ "GeoTransformer": dict(zip(metric_names, geotransformer_mean_results)),
296
+ "LoGDesc": dict(zip(metric_names, logdesc_mean_results)),
297
+ "RegTR": dict(zip(metric_names, regtr_mean_results)),
298
+ "R3PM-Net (ours) (ZS)": dict(zip(metric_names, r3pm_net_mean_results)),
299
+ "R3PM-Net (ours) (FT)": dict(zip(metric_names, tuned_r3pm_net_mean_results)),}
300
+
301
+ # Print the table
302
+ print_results.print_table(reports)
scripts/eval_sioux_scans.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import argparse
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+ from tabulate import tabulate
8
+ from tqdm import tqdm
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
13
+ if str(_REPO_ROOT) not in sys.path:
14
+ sys.path.insert(0, str(_REPO_ROOT))
15
+
16
+ from tools import data, l3d_helper, visualization
17
+ from tools import icp_registration_and_evaluation, l3d_registration_and_evaluation, predator_registration_and_evaluation, geotransformer_registration_and_evaluation, logdesc_registration_and_evaluation, regtr_registration_and_evaluation
18
+ from r3pm_net.config_loader import get_pretrained_rpmnet_dir, get_sioux_data_root, get_method_paths
19
+
20
+ '''
21
+ This script is used to evaluate the performance of the pipeline with R3PM-Net as global and GICP as local registeration.
22
+
23
+ The script takes the following arguments:
24
+ --local_reg: the local registration method to be used.
25
+ --seed: random seed for python/numpy/torch. The default is 42.
26
+ --verbose: if set to True, the results will be printed in a table format. The default is False.
27
+ '''
28
+ def set_seed(seed: int) -> None:
29
+ os.environ["PYTHONHASHSEED"] = str(seed)
30
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
31
+
32
+ random.seed(seed)
33
+ np.random.seed(seed)
34
+ torch.manual_seed(seed)
35
+ torch.cuda.manual_seed_all(seed)
36
+
37
+ torch.backends.cudnn.benchmark = False
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.use_deterministic_algorithms(True)
40
+
41
+
42
+ # arguments
43
+ parser = argparse.ArgumentParser(description="Choosing local registration method")
44
+ parser.add_argument(
45
+ "--local_reg", type=str, default="gicp", help="local registration: gicp or freg"
46
+ )
47
+ parser.add_argument("--seed", type=int, default=42, help="random seed (default: 42)")
48
+
49
+ args = parser.parse_args()
50
+ set_seed(args.seed)
51
+ print(f"Using {args.local_reg} for local registration")
52
+
53
+ def analyze_results(results: dict, recall_threshold = 1, rmse_threshold = 0.053, verbose = False): # change the default values to your needs
54
+ table = []
55
+ fail_count = 0
56
+ success_count = 0
57
+ for object, values in results.items():
58
+ row = [object] + list(values)
59
+ if round(row[2], 3) < recall_threshold or round(row[3], 3) > rmse_threshold:
60
+ status = 'failed'
61
+ fail_count += 1
62
+ print(f'No match for {object}! Try a different method. If the issue persists, please check the data.')
63
+ else:
64
+ status = 'success'
65
+ success_count += 1
66
+ print(f'Found match for {object}!')
67
+ row.append(status)
68
+ table.append(row)
69
+
70
+ if verbose:
71
+ print(tabulate(table, headers=['Object', 'Chamfer Distance', 'Reg. Recall', 'Inlier RMSE', 'Computation Time', 'Status'], tablefmt='grid'))
72
+ print(f"Success rate: {success_count / (success_count + fail_count) * 100:.2f}%")
73
+
74
+ return table
75
+
76
+ def show_successful_resutls(table, sources, targets, pc_results, method_name = None):
77
+ for i in range (len(table)):
78
+ if table[i][-1] == 'success':
79
+ # visualization.plot_point_cloud(sources[i], targets[i], list(pc_results.values())[i]) # uncomment if below visualization does not work
80
+ visualization.draw_registration_result(targets[i], list(pc_results.values())[i], np.eye(4), method_name)
81
+
82
+ def main():
83
+ base_dir = get_sioux_data_root()
84
+ scan_dir = os.path.join(base_dir, 'sioux_scans')
85
+ cad_dir = os.path.join(base_dir, 'sioux_cranfield')
86
+
87
+ pcd_paths = [ os.path.join(scan_dir,'teeth_clean.ply'),
88
+ os.path.join(scan_dir,'lime_clean.ply'),
89
+ os.path.join(scan_dir,'cube_clean.ply'),
90
+ os.path.join(scan_dir,'lego_clean.ply'),
91
+ os.path.join(scan_dir,'elephant_clean.ply'),
92
+ os.path.join(scan_dir,'house_clean.ply'),
93
+ os.path.join(scan_dir,'shoe_clean.ply')]
94
+
95
+ cad_paths = [ os.path.join(cad_dir,'teeth.stl'),
96
+ os.path.join(cad_dir,'lime.stl'),
97
+ os.path.join(cad_dir,'cube.stl'),
98
+ os.path.join(cad_dir,'lego.stl'),
99
+ os.path.join(cad_dir,'elephant.stl'),
100
+ os.path.join(cad_dir,'house.stl'),
101
+ os.path.join(cad_dir,'shoe.stl')]
102
+
103
+ # Initialize lists and dictionaries to store results
104
+ rpm_net_results = {}
105
+ rpm_net_pc_results = {}
106
+ predator_results = {}
107
+ predator_pc_results = {}
108
+ geotransformer_results = {}
109
+ geotransformer_pc_results = {}
110
+ logdesc_results = {}
111
+ logdesc_pc_results = {}
112
+ regtr_results = {}
113
+ regtr_pc_results = {}
114
+ r3pm_net_results = {}
115
+ r3pm_net_pc_results ={}
116
+ tuned_r3pm_net_results = {}
117
+ tuned_r3pm_net_pc_results = {}
118
+ subset_tuned_r3pm_net_results = {}
119
+ subset_tuned_r3pm_net_pc_results = {}
120
+
121
+ sources = []
122
+ targets = []
123
+
124
+ pretrained_base_dir = get_pretrained_rpmnet_dir()
125
+ method_paths = get_method_paths()
126
+ _path_zs = os.path.join(pretrained_base_dir, "clean-trained.pth")
127
+ _path_ft = os.path.join(pretrained_base_dir, "best_model_PointNet2.t7") #TODO: CHANGE
128
+ _path_ft_sub = os.path.join(pretrained_base_dir, "best_model_PointNet_subset.t7") #TODO: CHANGE
129
+
130
+ # set arguments for models
131
+ rpm_args = l3d_helper.options(modelName="RPMNet")
132
+ rpm_args.pretrained = _path_zs
133
+
134
+ # OverlapPredator (used by Predator runner)
135
+ predator_cfg = method_paths.get("predator", {})
136
+ predator_root = predator_cfg.get("root")
137
+ predator_config_path = predator_cfg.get("config_path")
138
+ predator_weights_path = predator_cfg.get("weights_path")
139
+
140
+ # GeoTransformer
141
+ geo_cfg = method_paths.get("geotransformer", {})
142
+ geotransformer_root = geo_cfg.get("root")
143
+ geotransformer_exp_subdir = geo_cfg.get("exp_subdir")
144
+ geotransformer_weights_path = geo_cfg.get("weights_path")
145
+
146
+ # LoGDesc
147
+ logdesc_cfg = method_paths.get("logdesc", {})
148
+ logdesc_root = logdesc_cfg.get("root")
149
+ logdesc_weights_path = logdesc_cfg.get("weights_path")
150
+
151
+ # RegTR
152
+ regtr_cfg = method_paths.get("regtr", {})
153
+ regtr_root = regtr_cfg.get("root")
154
+ regtr_ckpt_path = regtr_cfg.get("ckpt_path")
155
+ regtr_config_path = regtr_cfg.get("config_path")
156
+
157
+ # R3PM-Net (ours) - no training
158
+ r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
159
+ r3pm_net_args.pretrained = _path_zs
160
+
161
+ # R3PM-Net (ours) (FT)
162
+ tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
163
+ tuned_r3pm_net_args.pretrained = _path_ft
164
+
165
+ # R3PM-Net (ours) (FT) (Subset)
166
+ subset_tuned_r3pm_net_args = l3d_helper.options(modelName="R3PMNet")
167
+ subset_tuned_r3pm_net_args.pretrained = _path_ft_sub
168
+
169
+ for pcdPath, cadPath in tqdm(zip(pcd_paths, cad_paths), desc="Registering objects", total=len(pcd_paths)):
170
+ # Define the number of points to sample from the CAD model (change this based on your data)
171
+ if 'teeth' in pcdPath:
172
+ every_k_points = 100
173
+ key = 'teeth'
174
+ elif'lime' in pcdPath:
175
+ every_k_points = 100
176
+ key = 'lime'
177
+ elif 'cube' in pcdPath:
178
+ every_k_points = 1
179
+ key = 'cube'
180
+ elif 'lego' in pcdPath:
181
+ every_k_points = 10
182
+ key = 'lego'
183
+ elif 'elephant' in pcdPath:
184
+ every_k_points = 30
185
+ key = 'elephant'
186
+ elif 'house' in pcdPath:
187
+ every_k_points = 25
188
+ key = 'house'
189
+ elif 'shoe' in pcdPath:
190
+ every_k_points = 15
191
+ key = 'shoe'
192
+ else:
193
+ print("Unknown object type, using default every_k_points = 1")
194
+ every_k_points = 1
195
+
196
+ # Load the data
197
+ pcd, cad = data.load_data(pcdPath, cadPath, every_k_points=every_k_points)
198
+ source = copy.deepcopy(pcd)
199
+ target = copy.deepcopy(cad)
200
+
201
+ # Normalize the point clouds
202
+ source = data.normalize_pc(source)
203
+ target = data.normalize_pc(target)
204
+
205
+ sources.append(source)
206
+ targets.append(target)
207
+
208
+ gt_transformation = None
209
+
210
+ # Perform the registration
211
+
212
+ # RPMNet
213
+ rpm_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(
214
+ source, target, 'rpmnet', gt_transformation, rpm_args)
215
+ if args.local_reg == 'gicp':
216
+ final_rpm_net_pc_result, final_rpm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(rpm_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
217
+ rpm_net_results[key] = final_rpm_net_results
218
+ rpm_net_pc_results[key] = final_rpm_net_pc_result
219
+
220
+ # OverlapPredator
221
+ predator_results_pc, _ = predator_registration_and_evaluation.predator_reg_and_eval(
222
+ source,
223
+ target,
224
+ gt_transformation=gt_transformation,
225
+ predator_root=predator_root,
226
+ config_path=predator_config_path,
227
+ weights_path=predator_weights_path,
228
+ ransac_n_points=1000,
229
+ ransac_distance_threshold=0.05,
230
+ ransac_n=3,
231
+ sampling="prob",
232
+ mutual=False,
233
+ input_num_points=1024,
234
+ )
235
+ if args.local_reg == 'gicp':
236
+ final_predator_pc_result, final_predator_results = icp_registration_and_evaluation.icp_reg_and_eval(predator_results_pc, target, 'gicp', 1, np.identity(4), gt_transformation)
237
+ predator_results[key] = final_predator_results
238
+ predator_pc_results[key] = final_predator_pc_result
239
+
240
+ # GeoTransformer (ModelNet)
241
+ geotransformer_pc_result, _ = geotransformer_registration_and_evaluation.geotransformer_reg_and_eval(
242
+ source,
243
+ target,
244
+ gt_transformation=gt_transformation,
245
+ geotransformer_root=geotransformer_root,
246
+ exp_subdir=geotransformer_exp_subdir,
247
+ weights_path=geotransformer_weights_path,
248
+ )
249
+ if args.local_reg == 'gicp':
250
+ final_geotransformer_pc_result, final_geotransformer_results = icp_registration_and_evaluation.icp_reg_and_eval(geotransformer_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
251
+ geotransformer_results[key] = final_geotransformer_results
252
+ geotransformer_pc_results[key] = final_geotransformer_pc_result
253
+
254
+ # LoGDesc
255
+ logdesc_pc_result, _ = logdesc_registration_and_evaluation.logdesc_reg_and_eval(
256
+ source,
257
+ target,
258
+ gt_transformation=gt_transformation,
259
+ logdesc_root=logdesc_root,
260
+ weights_path=logdesc_weights_path,
261
+ max_keypoints=768,
262
+ num_points_per_sample=128,
263
+ sample_radius=0.3,
264
+ topk_matches=128,
265
+ use_kpt=False,
266
+ )
267
+ if args.local_reg == 'gicp':
268
+ final_logdesc_pc_result, final_logdesc_results = icp_registration_and_evaluation.icp_reg_and_eval(logdesc_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
269
+ logdesc_results[key] = final_logdesc_results
270
+ logdesc_pc_results[key] = final_logdesc_pc_result
271
+
272
+ # RegTR (ModelNet)
273
+ regtr_pc_result, _ = regtr_registration_and_evaluation.regtr_reg_and_eval(
274
+ source,
275
+ target,
276
+ gt_transformation=gt_transformation,
277
+ regtr_root=regtr_root,
278
+ ckpt_path=regtr_ckpt_path,
279
+ config_path=regtr_config_path,
280
+ )
281
+ if args.local_reg == 'gicp':
282
+ final_regtr_pc_result, final_regtr_results = icp_registration_and_evaluation.icp_reg_and_eval(regtr_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
283
+ regtr_results[key] = final_regtr_results
284
+ regtr_pc_results[key] = final_regtr_pc_result
285
+
286
+ # R3PM-Net (ours) (ZS)
287
+ r3pm_net_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(source, target, 'r3pmnet', gt_transformation, r3pm_net_args)
288
+ if args.local_reg == 'gicp':
289
+ final_r3pm_net_pc_result, final_r3pm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(r3pm_net_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
290
+ r3pm_net_results[key] = final_r3pm_net_results
291
+ r3pm_net_pc_results[key] = final_r3pm_net_pc_result
292
+
293
+ # R3PM-Net (ours) (FT)
294
+ tuned_r3pm_net_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(source, target, 'r3pmnet', gt_transformation, tuned_r3pm_net_args)
295
+ if args.local_reg == 'gicp':
296
+ final_tuned_r3pm_net_pc_result, final_tuned_r3pm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(tuned_r3pm_net_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
297
+ tuned_r3pm_net_results[key] = final_tuned_r3pm_net_results
298
+ tuned_r3pm_net_pc_results[key] = final_tuned_r3pm_net_pc_result
299
+
300
+ # R3PM-Net (ours) (FT) (Subset)
301
+ subset_tuned_r3pm_net_pc_result, _ = l3d_registration_and_evaluation.l3d_reg_and_eval(source, target, 'r3pmnet', gt_transformation, subset_tuned_r3pm_net_args)
302
+ if args.local_reg == 'gicp':
303
+ final_subset_tuned_r3pm_net_pc_result, final_subset_tuned_r3pm_net_results = icp_registration_and_evaluation.icp_reg_and_eval(subset_tuned_r3pm_net_pc_result, target, 'gicp', 1, np.identity(4), gt_transformation)
304
+ subset_tuned_r3pm_net_results[key] = final_subset_tuned_r3pm_net_results
305
+ subset_tuned_r3pm_net_pc_results[key] = final_subset_tuned_r3pm_net_pc_result
306
+
307
+ # Print the results
308
+ print("----- RPMNet: -----")
309
+ rpm_net_table = analyze_results(rpm_net_results, verbose=True)
310
+ show_successful_resutls(rpm_net_table, sources, targets, rpm_net_pc_results, 'RPMNet')
311
+
312
+ print("----- Predator: -----")
313
+ predator_table = analyze_results(predator_results, verbose=True)
314
+ show_successful_resutls(predator_table, sources, targets, predator_pc_results, 'Predator')
315
+
316
+ print("----- GeoTransformer: -----")
317
+ geotransformer_table = analyze_results(geotransformer_results, verbose=True)
318
+ show_successful_resutls(geotransformer_table, sources, targets, geotransformer_pc_results, 'GeoTransformer')
319
+
320
+ print("----- LoGDesc: -----")
321
+ logdesc_table = analyze_results(logdesc_results, verbose=True)
322
+ show_successful_resutls(logdesc_table, sources, targets, logdesc_pc_results, 'LoGDesc')
323
+
324
+ print("----- RegTR: -----")
325
+ regtr_table = analyze_results(regtr_results, verbose=True)
326
+ show_successful_resutls(regtr_table, sources, targets, regtr_pc_results, 'RegTR')
327
+
328
+ print("----- R3PM-Net (ours) (ZS): -----")
329
+ r3pm_net_table = analyze_results(r3pm_net_results, verbose=True)
330
+ show_successful_resutls(r3pm_net_table, sources, targets, r3pm_net_pc_results, 'R3PM-Net (ours) (ZS)')
331
+
332
+ print("----- R3PM-Net (ours) (FT): ----- ")
333
+ tuned_r3pm_net_table = analyze_results(tuned_r3pm_net_results, verbose=True)
334
+ show_successful_resutls(tuned_r3pm_net_table, sources, targets, tuned_r3pm_net_pc_results, 'R3PM-Net (ours) (FT)')
335
+
336
+ print("----- R3PM-Net (ours) (FT) (Subset): ----- ")
337
+ subset_tuned_r3pm_net_table = analyze_results(subset_tuned_r3pm_net_results, verbose=True)
338
+ show_successful_resutls(subset_tuned_r3pm_net_table, sources, targets, subset_tuned_r3pm_net_pc_results, 'R3PM-Net (ours) (FT) (Subset)')
339
+
340
+ if __name__ == "__main__":
341
+ main()
scripts/modelnet40.sh ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --partition=gpu_h100
3
+ #SBATCH --gpus=1
4
+ #SBATCH --job-name=modelnet40
5
+ #SBATCH --ntasks=1
6
+ #SBATCH --time=09:00:00
7
+ #SBATCH --output=modelnet40_output_%A.txt
8
+ #SBATCH --error=modelnet40_error_%A.txt
9
+
10
+ # Load necessary modules (adjust based on your environment)
11
+ module purge
12
+ module load 2023
13
+ module load CUDA/12.1.1
14
+
15
+ # my miniconda3 path
16
+ export PATH="$HOME/miniconda3/bin:$PATH"
17
+ unset -f conda 2>/dev/null
18
+ source "$HOME/miniconda3/etc/profile.d/conda.sh"
19
+
20
+ # Activate the conda environment
21
+ conda activate r3pm_net
22
+
23
+ if [[ -n "${SLURM_SUBMIT_DIR:-}" ]]; then
24
+ REPO_ROOT="$(cd "${SLURM_SUBMIT_DIR}" && pwd)"
25
+ else
26
+ REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
27
+ fi
28
+ cd "$REPO_ROOT" || { echo "ERROR: cannot cd to REPO_ROOT=${REPO_ROOT}" >&2; exit 1; }
29
+ if [[ ! -f "${REPO_ROOT}/pyproject.toml" ]]; then
30
+ echo "ERROR: REPO_ROOT=${REPO_ROOT} is not the r3pm_net tree (missing pyproject.toml)." >&2
31
+ echo "Run: cd /path/to/r3pm_net && sbatch scripts/modelnet40.sh" >&2
32
+ exit 1
33
+ fi
34
+
35
+ LOGDIR="${REPO_ROOT}/logs/slurm"
36
+ mkdir -p "$LOGDIR"
37
+ JOB_ID="${SLURM_JOB_ID:-local}"
38
+
39
+ # seeds=(42 61 92 114 123 456 789)
40
+ seeds=(42)
41
+
42
+ for seed in "${seeds[@]}"; do
43
+ srun python scripts/eval_modelnet40.py --seed "${seed}" \
44
+ >"${LOGDIR}/modelnet40_job${JOB_ID}_seed${seed}.log" 2>&1
45
+ done
scripts/sioux_cranfield.sh ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --partition=gpu_h100
3
+ #SBATCH --gpus=1
4
+ #SBATCH --job-name=sioux_cranfield
5
+ #SBATCH --ntasks=1
6
+ #SBATCH --time=04:00:00
7
+ #SBATCH --output=sioux_cranfield_output_%A.txt
8
+ #SBATCH --error=sioux_cranfield_error_%A.txt
9
+
10
+ # Load necessary modules (adjust based on your environment)
11
+ module purge
12
+ module load 2023
13
+ module load CUDA/12.1.1
14
+
15
+ # my miniconda3 path
16
+ export PATH="$HOME/miniconda3/bin:$PATH"
17
+ unset -f conda 2>/dev/null
18
+ source "$HOME/miniconda3/etc/profile.d/conda.sh"
19
+
20
+ # Activate the conda environment
21
+ conda activate r3pm_net
22
+
23
+
24
+ if [[ -n "${SLURM_SUBMIT_DIR:-}" ]]; then
25
+ REPO_ROOT="$(cd "${SLURM_SUBMIT_DIR}" && pwd)"
26
+ else
27
+ REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
28
+ fi
29
+ cd "$REPO_ROOT" || { echo "ERROR: cannot cd to REPO_ROOT=${REPO_ROOT}" >&2; exit 1; }
30
+ if [[ ! -f "${REPO_ROOT}/pyproject.toml" ]]; then
31
+ echo "ERROR: REPO_ROOT=${REPO_ROOT} is not the r3pm_net tree (missing pyproject.toml)." >&2
32
+ echo "Run: cd /path/to/r3pm_net && sbatch scripts/sioux_cranfield.sh" >&2
33
+ exit 1
34
+ fi
35
+
36
+ LOGDIR="${REPO_ROOT}/logs/slurm"
37
+ mkdir -p "$LOGDIR"
38
+ JOB_ID="${SLURM_JOB_ID:-local}"
39
+
40
+ # seeds=(42 61 92 114 123 456 789)
41
+ seeds=(42)
42
+
43
+ for seed in "${seeds[@]}"; do
44
+ srun python scripts/eval_sioux_cranfield.py --seed "${seed}" \
45
+ >"${LOGDIR}/sioux_cranfield_job${JOB_ID}_seed${seed}.log" 2>&1
46
+ done
scripts/sioux_scans.sh ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --partition=gpu_h100
3
+ #SBATCH --gpus=1
4
+ #SBATCH --job-name=sioux_scans
5
+ #SBATCH --ntasks=1
6
+ #SBATCH --time=01:00:00
7
+ #SBATCH --output=sioux_scans_output_%A.txt
8
+ #SBATCH --error=sioux_scans_error_%A.txt
9
+
10
+ # Load necessary modules (adjust based on your environment)
11
+ module purge
12
+ module load 2023
13
+ module load CUDA/12.1.1
14
+
15
+ # my miniconda3 path
16
+ export PATH="$HOME/miniconda3/bin:$PATH"
17
+ unset -f conda 2>/dev/null
18
+ source "$HOME/miniconda3/etc/profile.d/conda.sh"
19
+
20
+ # Activate the conda environment
21
+ conda activate r3pm_net
22
+
23
+ if [[ -n "${SLURM_SUBMIT_DIR:-}" ]]; then
24
+ REPO_ROOT="$(cd "${SLURM_SUBMIT_DIR}" && pwd)"
25
+ else
26
+ REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
27
+ fi
28
+ cd "$REPO_ROOT" || { echo "ERROR: cannot cd to REPO_ROOT=${REPO_ROOT}" >&2; exit 1; }
29
+ if [[ ! -f "${REPO_ROOT}/pyproject.toml" ]]; then
30
+ echo "ERROR: REPO_ROOT=${REPO_ROOT} is not the r3pm_net tree (missing pyproject.toml)." >&2
31
+ echo "Run: cd /path/to/r3pm_net && sbatch scripts/sioux_scans.sh" >&2
32
+ exit 1
33
+ fi
34
+
35
+ LOGDIR="${REPO_ROOT}/logs/slurm"
36
+ mkdir -p "$LOGDIR"
37
+ JOB_ID="${SLURM_JOB_ID:-local}"
38
+
39
+ # seeds=(42 61 92 114 123 456 789)
40
+ seeds=(42)
41
+
42
+ for seed in "${seeds[@]}"; do
43
+ srun python scripts/eval_sioux_scans.py --seed "${seed}" \
44
+ >"${LOGDIR}/sioux_scans_job${JOB_ID}_seed${seed}.log" 2>&1
45
+ done
src/train.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pickle
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from tensorboardX import SummaryWriter
10
+ from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
+
13
+ # Repository root on PYTHONPATH (for `python src/train.py` or srun).
14
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
15
+ if str(_REPO_ROOT) not in sys.path:
16
+ sys.path.insert(0, str(_REPO_ROOT))
17
+
18
+ from r3pm_net.model import R3PMNet
19
+ from r3pm_net.config_loader import parse_train_args, resolve_path_args
20
+ from r3pm_net.paths import REPO_ROOT
21
+ from thirdparty.learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
22
+ from dataloader.user_data import UserData
23
+ from r3pm_net.feature_extractor import feature_extractor # import your feature extractor here
24
+
25
+ def _init_(args):
26
+ Path(args.save_dir).mkdir(parents=True, exist_ok=True)
27
+ (REPO_ROOT / "checkpoints" / args.exp_name).mkdir(parents=True, exist_ok=True)
28
+
29
+ if os.path.isfile("main.py"):
30
+ os.system("cp main.py checkpoints" + "/" + args.exp_name + "/" + "main.py.backup")
31
+ if os.path.isfile("model.py"):
32
+ os.system("cp model.py checkpoints" + "/" + args.exp_name + "/" + "model.py.backup")
33
+
34
+
35
+ class IOStream:
36
+ def __init__(self, path):
37
+ self.f = open(path, "a")
38
+
39
+ def cprint(self, text):
40
+ print(text)
41
+ self.f.write(text + "\n")
42
+ self.f.flush()
43
+
44
+ def close(self):
45
+ self.f.close()
46
+
47
+
48
+ def test_one_epoch(device, model, test_loader):
49
+ model.eval()
50
+ test_loss = 0.0
51
+ count = 0
52
+ for i, data in enumerate(tqdm(test_loader)):
53
+ template, source, igt = data
54
+
55
+ template = template.to(device)
56
+ source = source.to(device)
57
+ igt = igt.to(device)
58
+
59
+ output = model(template, source)
60
+ loss_val = FrobeniusNormLoss()(output["est_T"], igt) + RMSEFeaturesLoss()(output["r"])
61
+
62
+ test_loss += loss_val.item()
63
+ count += 1
64
+
65
+ test_loss = float(test_loss) / count
66
+ return test_loss
67
+
68
+
69
+ def test(args, model, test_loader, textio):
70
+ test_loss = test_one_epoch(args.device, model, test_loader)
71
+ textio.cprint("Validation Loss: %f" % (test_loss))
72
+
73
+
74
+ def train_one_epoch(device, model, train_loader, optimizer):
75
+ model.train()
76
+ train_loss = 0.0
77
+ count = 0
78
+ for i, data in enumerate(tqdm(train_loader)):
79
+ template, source, igt = data
80
+
81
+ template = template.to(device)
82
+ source = source.to(device)
83
+ igt = igt.to(device)
84
+
85
+ output = model(template, source)
86
+ loss_val = FrobeniusNormLoss()(output["est_T"], igt) + RMSEFeaturesLoss()(output["r"])
87
+
88
+ optimizer.zero_grad()
89
+ loss_val.backward()
90
+ optimizer.step()
91
+
92
+ train_loss += loss_val.item()
93
+ count += 1
94
+
95
+ train_loss = float(train_loss) / count
96
+ return train_loss
97
+
98
+
99
+ def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
100
+ Path(args.save_dir).mkdir(parents=True, exist_ok=True)
101
+
102
+ learnable_params = filter(lambda p: p.requires_grad, model.parameters())
103
+ if args.optimizer == "Adam":
104
+ optimizer = torch.optim.Adam(learnable_params)
105
+ else:
106
+ optimizer = torch.optim.SGD(learnable_params, lr=0.1)
107
+
108
+ if checkpoint is not None:
109
+ optimizer.load_state_dict(checkpoint["optimizer"])
110
+
111
+ best_test_loss = np.inf
112
+
113
+ for epoch in range(args.start_epoch, args.epochs):
114
+ train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
115
+ test_loss = test_one_epoch(args.device, model, test_loader)
116
+
117
+ snap = {
118
+ "epoch": epoch + 1,
119
+ "model": model.state_dict(),
120
+ "min_loss": test_loss,
121
+ "optimizer": optimizer.state_dict(),
122
+ }
123
+
124
+ if test_loss < best_test_loss:
125
+ best_test_loss = test_loss
126
+ best_snap_path = os.path.join(
127
+ args.save_dir, "best_model_snap.t7")
128
+ best_model_path = os.path.join(
129
+ args.save_dir, "best_model.t7")
130
+
131
+ torch.save(snap, best_snap_path)
132
+ torch.save(model.state_dict(), best_model_path)
133
+
134
+ torch.save(snap, os.path.join(args.save_dir, "model_snap.t7"))
135
+ torch.save(model.state_dict(), os.path.join(args.save_dir, "model.t7"))
136
+
137
+ boardio.add_scalar("Train Loss", train_loss, epoch + 1)
138
+ boardio.add_scalar("Test Loss", test_loss, epoch + 1)
139
+ boardio.add_scalar("Best Test Loss", best_test_loss, epoch + 1)
140
+
141
+ textio.cprint(
142
+ "EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f"
143
+ % (epoch + 1, train_loss, test_loss, best_test_loss)
144
+ )
145
+
146
+
147
+ def build_parser(default_config_path: str):
148
+ parser = argparse.ArgumentParser(description="Point Cloud Registration")
149
+ parser.add_argument(
150
+ "--config",
151
+ type=str,
152
+ default=default_config_path,
153
+ help="YAML file with defaults (see config/default.yaml); can be overridden on the command line",
154
+ )
155
+ parser.add_argument(
156
+ "--exp_name",
157
+ type=str,
158
+ default="exp_r3pmnet",
159
+ metavar="N",
160
+ help="Name of the experiment",
161
+ )
162
+ parser.add_argument("--eval", action="store_true", help="Run evaluation only (no training).")
163
+ parser.add_argument(
164
+ "--save_dir",
165
+ type=str,
166
+ default="",
167
+ help="Directory to save model checkpoints (default: checkpoints/<exp_name>/models)",
168
+ )
169
+
170
+ parser.add_argument(
171
+ "--num_points",
172
+ default=1024,
173
+ type=int,
174
+ metavar="N",
175
+ help="points in point-cloud (default: 1024)",
176
+ )
177
+
178
+ parser.add_argument(
179
+ "--fine_tune_feature_extractor",
180
+ default="tune",
181
+ type=str,
182
+ choices=["fixed", "tune"],
183
+ help="train feature extractor (default: tune)",
184
+ )
185
+ parser.add_argument(
186
+ "--transfer_weights",
187
+ default="",
188
+ type=str,
189
+ metavar="PATH",
190
+ help="optional path to feature extractor checkpoint",
191
+ )
192
+ parser.add_argument(
193
+ "--symfn",
194
+ default="max",
195
+ choices=["max", "avg"],
196
+ help="symmetric function (default: max)",
197
+ )
198
+
199
+ parser.add_argument("--seed", type=int, default=1234)
200
+ parser.add_argument(
201
+ "-j",
202
+ "--workers",
203
+ default=4,
204
+ type=int,
205
+ metavar="N",
206
+ help="number of data loading workers (default: 4)",
207
+ )
208
+ parser.add_argument(
209
+ "-b",
210
+ "--batch_size",
211
+ default=5,
212
+ type=int,
213
+ metavar="N",
214
+ help="mini-batch size (default: 5)",
215
+ )
216
+ parser.add_argument(
217
+ "--epochs",
218
+ default=50,
219
+ type=int,
220
+ metavar="N",
221
+ help="number of total epochs to run",
222
+ )
223
+ parser.add_argument(
224
+ "--start_epoch",
225
+ default=0,
226
+ type=int,
227
+ metavar="N",
228
+ help="manual epoch number (useful on restarts)",
229
+ )
230
+ parser.add_argument(
231
+ "--optimizer",
232
+ default="Adam",
233
+ choices=["Adam", "SGD"],
234
+ metavar="METHOD",
235
+ help="name of an optimizer (default: Adam)",
236
+ )
237
+ parser.add_argument(
238
+ "--resume",
239
+ default="",
240
+ type=str,
241
+ metavar="PATH",
242
+ help="path to latest checkpoint (default: none)",
243
+ )
244
+ parser.add_argument(
245
+ "--pretrained",
246
+ default="",
247
+ type=str,
248
+ metavar="PATH",
249
+ help="path to pretrained full model (default: none)",
250
+ )
251
+ parser.add_argument(
252
+ "--device",
253
+ default="cuda:0",
254
+ type=str,
255
+ metavar="DEVICE",
256
+ help="use CUDA if available",
257
+ )
258
+
259
+ parser.add_argument(
260
+ "--train_dict_path",
261
+ type=str,
262
+ default="data/simulators/data_dict_train.pkl",
263
+ help="Pickled training data_dict",
264
+ )
265
+ parser.add_argument(
266
+ "--test_dict_path",
267
+ type=str,
268
+ default="data/simulators/data_dict_test.pkl",
269
+ help="Pickled test data_dict",
270
+ )
271
+
272
+ return parser
273
+
274
+
275
+ def _torch_load(path, map_location):
276
+ try:
277
+ return torch.load(path, map_location=map_location, weights_only=False)
278
+ except TypeError:
279
+ return torch.load(path, map_location=map_location)
280
+
281
+
282
+ def main():
283
+ args = parse_train_args(sys.argv[1:], build_parser)
284
+
285
+ resolve_path_args(
286
+ args,
287
+ (
288
+ "save_dir",
289
+ "train_dict_path",
290
+ "test_dict_path",
291
+ "resume",
292
+ "pretrained",
293
+ "transfer_weights",
294
+ ),
295
+ )
296
+
297
+ if not args.save_dir:
298
+ args.save_dir = str(REPO_ROOT / "checkpoints" / args.exp_name / "models")
299
+
300
+ torch.backends.cudnn.deterministic = True
301
+ torch.manual_seed(args.seed)
302
+ torch.cuda.manual_seed_all(args.seed)
303
+ np.random.seed(args.seed)
304
+
305
+ ckpt_dir = REPO_ROOT / "checkpoints" / args.exp_name
306
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
307
+ boardio = SummaryWriter(log_dir=str(ckpt_dir))
308
+ _init_(args)
309
+
310
+ textio = IOStream(str(ckpt_dir / "run.log"))
311
+ textio.cprint(str(args))
312
+
313
+ if not os.path.isfile(args.train_dict_path):
314
+ raise FileNotFoundError(f"Training dict not found: {args.train_dict_path}")
315
+ if not os.path.isfile(args.test_dict_path):
316
+ raise FileNotFoundError(f"Test dict not found: {args.test_dict_path}")
317
+
318
+ with open(args.train_dict_path, "rb") as f:
319
+ data_dict_train = pickle.load(f)
320
+ with open(args.test_dict_path, "rb") as f:
321
+ data_dict_test = pickle.load(f)
322
+
323
+ trainset = UserData("registration", data_dict_train)
324
+ testset = UserData("registration", data_dict_test)
325
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=False, drop_last=True, num_workers=args.workers)
326
+ test_loader = DataLoader(testset, batch_size=5, shuffle=False, drop_last=False, num_workers=args.workers)
327
+
328
+ if not torch.cuda.is_available():
329
+ args.device = "cpu"
330
+ args.device = torch.device(args.device)
331
+
332
+ # feature extractor model
333
+ FEATURE_MODEL = feature_extractor
334
+ model = R3PMNet(feature_model=FEATURE_MODEL)
335
+ model = model.to(args.device)
336
+
337
+ if args.transfer_weights and os.path.isfile(args.transfer_weights):
338
+ feat_model_dict = _torch_load(args.transfer_weights, args.device)
339
+ model.feat_extractor.load_state_dict(feat_model_dict)
340
+
341
+ checkpoint = None
342
+ if args.resume:
343
+ assert os.path.isfile(args.resume)
344
+ checkpoint = _torch_load(args.resume, args.device)
345
+ args.start_epoch = checkpoint["epoch"]
346
+ model.load_state_dict(checkpoint["model"])
347
+
348
+ if args.pretrained:
349
+ assert os.path.isfile(args.pretrained)
350
+ try:
351
+ model.load_state_dict(_torch_load(args.pretrained, "cpu"))
352
+ except RuntimeError:
353
+ model_data = _torch_load(args.pretrained, "cpu")
354
+ state_dict = model_data["state_dict"]
355
+ model.load_state_dict(state_dict)
356
+ model.to(args.device)
357
+
358
+ Path(args.save_dir).mkdir(parents=True, exist_ok=True)
359
+
360
+ if args.eval:
361
+ test(args, model, test_loader, textio)
362
+ else:
363
+ train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
364
+
365
+ if __name__ == "__main__":
366
+ main()
thirdparty/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Namespace for vendored thirdparty.learning3d
thirdparty/learning3d/data_utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .dataloaders import ModelNet40Data
2
+ from .dataloaders import ClassificationData, RegistrationData, SegmentationData, FlowData, SceneflowDataset
3
+ from .dataloaders import download_modelnet40, deg_to_rad, create_random_transform
4
+ from .user_data import UserData
thirdparty/learning3d/data_utils/dataloaders.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset
5
+ from torch.utils.data import DataLoader
6
+ import numpy as np
7
+ import os
8
+ import h5py
9
+ import subprocess
10
+ import shlex
11
+ import json
12
+ import glob
13
+ from .. ops import transform_functions, se3
14
+ from sklearn.neighbors import NearestNeighbors
15
+ from scipy.spatial.distance import minkowski
16
+ from scipy.spatial import cKDTree
17
+ from torch.utils.data import Dataset
18
+
19
+ def download_modelnet40():
20
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
22
+ if not os.path.exists(DATA_DIR):
23
+ os.mkdir(DATA_DIR)
24
+ if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
25
+ www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
26
+ zipfile = os.path.basename(www)
27
+ os.system('wget --no-check-certificate %s; unzip %s' % (www, zipfile))
28
+ os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
29
+ os.system('rm %s' % (zipfile))
30
+
31
+ def load_data(train, use_normals):
32
+ if train: partition = 'train'
33
+ else: partition = 'test'
34
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
35
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
36
+ all_data = []
37
+ all_label = []
38
+ for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)):
39
+ f = h5py.File(h5_name)
40
+ if use_normals: data = np.concatenate([f['data'][:], f['normal'][:]], axis=-1).astype('float32')
41
+ else: data = f['data'][:].astype('float32')
42
+ label = f['label'][:].astype('int64')
43
+ f.close()
44
+ all_data.append(data)
45
+ all_label.append(label)
46
+ all_data = np.concatenate(all_data, axis=0)
47
+ all_label = np.concatenate(all_label, axis=0)
48
+ return all_data, all_label
49
+
50
+ def deg_to_rad(deg):
51
+ return np.pi / 180 * deg
52
+
53
+ def create_random_transform(dtype, max_rotation_deg, max_translation):
54
+ max_rotation = deg_to_rad(max_rotation_deg)
55
+ rot = np.random.uniform(-max_rotation, max_rotation, [1, 3])
56
+ trans = np.random.uniform(-max_translation, max_translation, [1, 3])
57
+ quat = transform_functions.euler_to_quaternion(rot, "xyz")
58
+
59
+ vec = np.concatenate([quat, trans], axis=1)
60
+ vec = torch.tensor(vec, dtype=dtype)
61
+ return vec
62
+
63
+ def jitter_pointcloud(pointcloud, sigma=0.04, clip=0.05):
64
+ # N, C = pointcloud.shape
65
+ sigma = 0.04*np.random.random_sample()
66
+ pointcloud += torch.empty(pointcloud.shape).normal_(mean=0, std=sigma).clamp(-clip, clip)
67
+ return pointcloud
68
+
69
+ def farthest_subsample_points(pointcloud1, num_subsampled_points=768):
70
+ pointcloud1 = pointcloud1
71
+ num_points = pointcloud1.shape[0]
72
+ nbrs1 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto',
73
+ metric=lambda x, y: minkowski(x, y)).fit(pointcloud1[:, :3])
74
+ random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1])
75
+ idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((num_subsampled_points,))
76
+ gt_mask = torch.zeros(num_points).scatter_(0, torch.tensor(idx1), 1)
77
+ return pointcloud1[idx1, :], gt_mask
78
+
79
+ def uniform_2_sphere(num: int = None):
80
+ """Uniform sampling on a 2-sphere
81
+
82
+ Source: https://gist.github.com/andrewbolster/10274979
83
+
84
+ Args:
85
+ num: Number of vectors to sample (or None if single)
86
+
87
+ Returns:
88
+ Random Vector (np.ndarray) of size (num, 3) with norm 1.
89
+ If num is None returned value will have size (3,)
90
+
91
+ """
92
+ if num is not None:
93
+ phi = np.random.uniform(0.0, 2 * np.pi, num)
94
+ cos_theta = np.random.uniform(-1.0, 1.0, num)
95
+ else:
96
+ phi = np.random.uniform(0.0, 2 * np.pi)
97
+ cos_theta = np.random.uniform(-1.0, 1.0)
98
+
99
+ theta = np.arccos(cos_theta)
100
+ x = np.sin(theta) * np.cos(phi)
101
+ y = np.sin(theta) * np.sin(phi)
102
+ z = np.cos(theta)
103
+
104
+ return np.stack((x, y, z), axis=-1)
105
+
106
+ def planar_crop(points, p_keep= 0.7):
107
+ p_keep = np.array(p_keep, dtype=np.float32)
108
+
109
+ rand_xyz = uniform_2_sphere()
110
+ pts = points.numpy()
111
+ centroid = np.mean(pts[:, :3], axis=0)
112
+ points_centered = pts[:, :3] - centroid
113
+
114
+ dist_from_plane = np.dot(points_centered, rand_xyz)
115
+
116
+ mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100)
117
+ idx_x = torch.Tensor(np.nonzero(mask))
118
+
119
+ return torch.Tensor(pts[mask, :3]), idx_x
120
+
121
+ def knn_idx(pts, k):
122
+ kdt = cKDTree(pts)
123
+ _, idx = kdt.query(pts, k=k+1)
124
+ return idx[:, 1:]
125
+
126
+ def get_rri(pts, k):
127
+ # pts: N x 3, original points
128
+ # q: N x K x 3, nearest neighbors
129
+ q = pts[knn_idx(pts, k)]
130
+ p = np.repeat(pts[:, None], k, axis=1)
131
+ # rp, rq: N x K x 1, norms
132
+ rp = np.linalg.norm(p, axis=-1, keepdims=True)
133
+ rq = np.linalg.norm(q, axis=-1, keepdims=True)
134
+ pn = p / rp
135
+ qn = q / rq
136
+ dot = np.sum(pn * qn, -1, keepdims=True)
137
+ # theta: N x K x 1, angles
138
+ theta = np.arccos(np.clip(dot, -1, 1))
139
+ T_q = q - dot * p
140
+ sin_psi = np.sum(np.cross(T_q[:, None], T_q[:, :, None]) * pn[:, None], -1)
141
+ cos_psi = np.sum(T_q[:, None] * T_q[:, :, None], -1)
142
+ psi = np.arctan2(sin_psi, cos_psi) % (2*np.pi)
143
+ idx = np.argpartition(psi, 1)[:, :, 1:2]
144
+ # phi: N x K x 1, projection angles
145
+ phi = np.take_along_axis(psi, idx, axis=-1)
146
+ feat = np.concatenate([rp, rq, theta, phi], axis=-1)
147
+ return feat.reshape(-1, k * 4)
148
+
149
+ def get_rri_cuda(pts, k, npts_per_block=1):
150
+ try:
151
+ import pycuda.autoinit
152
+ from pycuda import gpuarray
153
+ from pycuda.compiler import SourceModule
154
+ except Exception as e:
155
+ print("Error raised in pycuda modules! pycuda only works with GPU, ", e)
156
+ raise
157
+
158
+ mod_rri = SourceModule(open('rri.cu').read() % (k, npts_per_block))
159
+ rri_cuda = mod_rri.get_function('get_rri_feature')
160
+
161
+ N = len(pts)
162
+ pts_gpu = gpuarray.to_gpu(pts.astype(np.float32).ravel())
163
+ k_idx = knn_idx(pts, k)
164
+ k_idx_gpu = gpuarray.to_gpu(k_idx.astype(np.int32).ravel())
165
+ feat_gpu = gpuarray.GPUArray((N * k * 4,), np.float32)
166
+
167
+ rri_cuda(pts_gpu, np.int32(N), k_idx_gpu, feat_gpu,
168
+ grid=(((N-1) // npts_per_block)+1, 1),
169
+ block=(npts_per_block, k, 1))
170
+
171
+ feat = feat_gpu.get().reshape(N, k * 4).astype(np.float32)
172
+ return feat
173
+
174
+
175
+ class UnknownDataTypeError(Exception):
176
+ def __init__(self, *args):
177
+ if args: self.message = args[0]
178
+ else: self.message = 'Datatype not understood for dataset.'
179
+
180
+ def __str__(self):
181
+ return self.message
182
+
183
+
184
+ class ModelNet40Data(Dataset):
185
+ def __init__(
186
+ self,
187
+ train=True,
188
+ num_points=1024,
189
+ download=True,
190
+ randomize_data=False,
191
+ use_normals=False
192
+ ):
193
+ super(ModelNet40Data, self).__init__()
194
+ if download: download_modelnet40()
195
+ self.data, self.labels = load_data(train, use_normals)
196
+ if not train: self.shapes = self.read_classes_ModelNet40()
197
+ self.num_points = num_points
198
+ self.randomize_data = randomize_data
199
+
200
+ def __getitem__(self, idx):
201
+ if self.randomize_data: current_points = self.randomize(idx)
202
+ else: current_points = self.data[idx].copy()
203
+
204
+ current_points = torch.from_numpy(current_points[:self.num_points, :]).float()
205
+ label = torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
206
+
207
+ return current_points, label
208
+
209
+ def __len__(self):
210
+ return self.data.shape[0]
211
+
212
+ def randomize(self, idx):
213
+ pt_idxs = np.arange(0, self.num_points)
214
+ np.random.shuffle(pt_idxs)
215
+ return self.data[idx, pt_idxs].copy()
216
+
217
+ def get_shape(self, label):
218
+ return self.shapes[label]
219
+
220
+ def read_classes_ModelNet40(self):
221
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
222
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
223
+ file = open(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'shape_names.txt'), 'r')
224
+ shape_names = file.read()
225
+ shape_names = np.array(shape_names.split('\n')[:-1])
226
+ return shape_names
227
+
228
+
229
+ class ClassificationData(Dataset):
230
+ def __init__(self, data_class=ModelNet40Data()):
231
+ super(ClassificationData, self).__init__()
232
+ self.set_class(data_class)
233
+
234
+ def __len__(self):
235
+ return len(self.data_class)
236
+
237
+ def set_class(self, data_class):
238
+ self.data_class = data_class
239
+
240
+ def get_shape(self, label):
241
+ try:
242
+ return self.data_class.get_shape(label)
243
+ except:
244
+ return -1
245
+
246
+ def __getitem__(self, index):
247
+ return self.data_class[index]
248
+
249
+
250
+ class RegistrationData(Dataset):
251
+ def __init__(self, algorithm, data_class=ModelNet40Data(), partial_source=False, partial_template=False, noise=False, additional_params={}):
252
+ super(RegistrationData, self).__init__()
253
+ available_algorithms = ['PCRNet', 'PointNetLK', 'DCP', 'PRNet', 'iPCRNet', 'RPMNet', 'DeepGMR']
254
+ if algorithm in available_algorithms: self.algorithm = algorithm
255
+ else: raise Exception("Algorithm not available for registration.")
256
+
257
+ self.set_class(data_class)
258
+ self.partial_template = partial_template
259
+ self.partial_source = partial_source
260
+ self.noise = noise
261
+ self.additional_params = additional_params
262
+ self.use_rri = False
263
+
264
+ if self.algorithm == 'PCRNet' or self.algorithm == 'iPCRNet':
265
+ from .. ops.transform_functions import PCRNetTransform
266
+ self.transforms = PCRNetTransform(len(data_class), angle_range=45, translation_range=1)
267
+ if self.algorithm == 'PointNetLK':
268
+ from .. ops.transform_functions import PNLKTransform
269
+ self.transforms = PNLKTransform(0.8, True)
270
+ if self.algorithm == 'RPMNet':
271
+ from .. ops.transform_functions import RPMNetTransform
272
+ self.transforms = RPMNetTransform(0.8, True)
273
+ if self.algorithm == 'DCP' or self.algorithm == 'PRNet':
274
+ from .. ops.transform_functions import DCPTransform
275
+ self.transforms = DCPTransform(angle_range=45, translation_range=1)
276
+ if self.algorithm == 'DeepGMR':
277
+ self.get_rri = get_rri_cuda if torch.cuda.is_available() else get_rri
278
+ from .. ops.transform_functions import DeepGMRTransform
279
+ self.transforms = DeepGMRTransform(angle_range=90, translation_range=1)
280
+ if 'nearest_neighbors' in self.additional_params.keys() and self.additional_params['nearest_neighbors'] > 0:
281
+ self.use_rri = True
282
+ self.nearest_neighbors = self.additional_params['nearest_neighbors']
283
+
284
+ def __len__(self):
285
+ return len(self.data_class)
286
+
287
+ def set_class(self, data_class):
288
+ self.data_class = data_class
289
+
290
+ def __getitem__(self, index):
291
+ template, label = self.data_class[index]
292
+ self.transforms.index = index # for fixed transformations in PCRNet.
293
+ source = self.transforms(template)
294
+
295
+ # Check for Partial Data.
296
+ if self.additional_params.get('partial_point_cloud_method', None) == 'planar_crop':
297
+ source, gt_idx_source = planar_crop(source)
298
+ template, gt_idx_template = planar_crop(template)
299
+ intersect_mask, intersect_x, intersect_y = np.intersect1d(gt_idx_source, gt_idx_template, return_indices=True)
300
+
301
+ self.template_mask = torch.zeros(template.shape[0])
302
+ self.source_mask = torch.zeros(source.shape[0])
303
+ self.template_mask[intersect_y] = 1
304
+ self.source_mask[intersect_x] = 1
305
+ else:
306
+ if self.partial_source: source, self.source_mask = farthest_subsample_points(source)
307
+ if self.partial_template: template, self.template_mask = farthest_subsample_points(template)
308
+
309
+
310
+
311
+ # Check for Noise in Source Data.
312
+ if self.noise: source = jitter_pointcloud(source)
313
+
314
+ if self.use_rri:
315
+ template, source = template.numpy(), source.numpy()
316
+ template = np.concatenate([template, self.get_rri(template - template.mean(axis=0), self.nearest_neighbors)], axis=1)
317
+ source = np.concatenate([source, self.get_rri(source - source.mean(axis=0), self.nearest_neighbors)], axis=1)
318
+ template, source = torch.tensor(template).float(), torch.tensor(source).float()
319
+
320
+ igt = self.transforms.igt
321
+
322
+ if self.additional_params.get('use_masknet', False):
323
+ if self.partial_source and self.partial_template:
324
+ return template, source, igt, self.template_mask, self.source_mask
325
+ elif self.partial_source:
326
+ return template, source, igt, self.source_mask
327
+ elif self.partial_template:
328
+ return template, source, igt, self.template_mask
329
+ else:
330
+ return template, source, igt
331
+
332
+
333
+ class SegmentationData(Dataset):
334
+ def __init__(self):
335
+ super(SegmentationData, self).__init__()
336
+
337
+ def __len__(self):
338
+ pass
339
+
340
+ def __getitem__(self, index):
341
+ pass
342
+
343
+
344
+ class FlowData(Dataset):
345
+ def __init__(self):
346
+ super(FlowData, self).__init__()
347
+ self.pc1, self.pc2, self.flow = self.read_data()
348
+
349
+ def __len__(self):
350
+ if isinstance(self.pc1, np.ndarray):
351
+ return self.pc1.shape[0]
352
+ elif isinstance(self.pc1, list):
353
+ return len(self.pc1)
354
+ else:
355
+ raise UnknownDataTypeError
356
+
357
+ def read_data(self):
358
+ pass
359
+
360
+ def __getitem__(self, index):
361
+ return self.pc1[index], self.pc2[index], self.flow[index]
362
+
363
+
364
+ class SceneflowDataset(Dataset):
365
+ def __init__(self, npoints=1024, root='', partition='train'):
366
+ if root == '':
367
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
368
+ DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
369
+ root = os.path.join(DATA_DIR, 'data_processed_maxcut_35_20k_2k_8192')
370
+ if not os.path.exists(root):
371
+ print("To download dataset, click here: https://drive.google.com/file/d/1CMaxdt-Tg1Wct8v8eGNwuT7qRSIyJPY-/view")
372
+ exit()
373
+ else:
374
+ print("SceneflowDataset Found Successfully!")
375
+
376
+ self.npoints = npoints
377
+ self.partition = partition
378
+ self.root = root
379
+ if self.partition=='train':
380
+ self.datapath = glob.glob(os.path.join(self.root, 'TRAIN*.npz'))
381
+ else:
382
+ self.datapath = glob.glob(os.path.join(self.root, 'TEST*.npz'))
383
+ self.cache = {}
384
+ self.cache_size = 30000
385
+
386
+ ###### deal with one bad datapoint with nan value
387
+ self.datapath = [d for d in self.datapath if 'TRAIN_C_0140_left_0006-0' not in d]
388
+ ######
389
+ print(self.partition, ': ',len(self.datapath))
390
+
391
+ def __getitem__(self, index):
392
+ if index in self.cache:
393
+ pos1, pos2, color1, color2, flow, mask1 = self.cache[index]
394
+ else:
395
+ fn = self.datapath[index]
396
+ with open(fn, 'rb') as fp:
397
+ data = np.load(fp)
398
+ pos1 = data['points1'].astype('float32')
399
+ pos2 = data['points2'].astype('float32')
400
+ color1 = data['color1'].astype('float32')
401
+ color2 = data['color2'].astype('float32')
402
+ flow = data['flow'].astype('float32')
403
+ mask1 = data['valid_mask1']
404
+
405
+ if len(self.cache) < self.cache_size:
406
+ self.cache[index] = (pos1, pos2, color1, color2, flow, mask1)
407
+
408
+ if self.partition == 'train':
409
+ n1 = pos1.shape[0]
410
+ sample_idx1 = np.random.choice(n1, self.npoints, replace=False)
411
+ n2 = pos2.shape[0]
412
+ sample_idx2 = np.random.choice(n2, self.npoints, replace=False)
413
+
414
+ pos1 = pos1[sample_idx1, :]
415
+ pos2 = pos2[sample_idx2, :]
416
+ color1 = color1[sample_idx1, :]
417
+ color2 = color2[sample_idx2, :]
418
+ flow = flow[sample_idx1, :]
419
+ mask1 = mask1[sample_idx1]
420
+ else:
421
+ pos1 = pos1[:self.npoints, :]
422
+ pos2 = pos2[:self.npoints, :]
423
+ color1 = color1[:self.npoints, :]
424
+ color2 = color2[:self.npoints, :]
425
+ flow = flow[:self.npoints, :]
426
+ mask1 = mask1[:self.npoints]
427
+
428
+ pos1_center = np.mean(pos1, 0)
429
+ pos1 -= pos1_center
430
+ pos2 -= pos1_center
431
+
432
+ return pos1, pos2, color1, color2, flow, mask1
433
+
434
+ def __len__(self):
435
+ return len(self.datapath)
436
+
437
+
438
+ if __name__ == '__main__':
439
+ class Data():
440
+ def __init__(self):
441
+ super(Data, self).__init__()
442
+ self.data, self.label = self.read_data()
443
+
444
+ def read_data(self):
445
+ return [4,5,6], [4,5,6]
446
+
447
+ def __len__(self):
448
+ return len(self.data)
449
+
450
+ def __getitem__(self, idx):
451
+ return self.data[idx], self.label[idx]
452
+
453
+ cd = RegistrationData('abc')
454
+ import ipdb; ipdb.set_trace()
thirdparty/learning3d/data_utils/user_data.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+
5
+ class ClassificationData:
6
+ def __init__(self, data_dict):
7
+ self.data_dict = data_dict
8
+ self.pcs = self.find_attribute('pcs')
9
+ self.labels = self.find_attribute('labels')
10
+ self.check_data()
11
+
12
+ def find_attribute(self, attribute):
13
+ try:
14
+ attribute_data = self.data_dict[attribute]
15
+ except:
16
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
17
+ return attribute_data
18
+
19
+ def check_data(self):
20
+ assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
21
+ assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)
22
+
23
+ if len(self.pcs.shape)==2: self.pcs = self.pcs.reshape(1, -1, 3)
24
+ if len(self.labels.shape) == 1: self.labels = self.labels.reshape(1, -1)
25
+
26
+ assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"
27
+
28
+
29
+ def __len__(self):
30
+ return self.pcs.shape[0]
31
+
32
+ def __getitem__(self, index):
33
+ return torch.tensor(self.pcs[index]).float(), torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
34
+
35
+
36
+ class RegistrationData:
37
+ def __init__(self, data_dict):
38
+ self.data_dict = data_dict
39
+ self.template = self.find_attribute('template')
40
+ self.source = self.find_attribute('source')
41
+ self.transformation = self.find_attribute('transformation')
42
+ self.check_data()
43
+
44
+ def find_attribute(self, attribute):
45
+ try:
46
+ attribute_data = self.data[attribute]
47
+ except:
48
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
49
+ return attribute_data
50
+
51
+ def check_data(self):
52
+ assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
53
+ assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
54
+ assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)
55
+
56
+ if len(self.template.shape)==2: self.template = self.template.reshape(1, -1, 3)
57
+ if len(self.source.shape)==2: self.source = self.source.reshape(1, -1, 3)
58
+ if len(self.transformation.shape) == 2: self.transformation = self.transformation.reshape(1, 4, 4)
59
+
60
+ assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
61
+ assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"
62
+
63
+ def __len__(self):
64
+ return self.template.shape[0]
65
+
66
+ def __getitem__(self, index):
67
+ return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()
68
+
69
+
70
+ class FlowData:
71
+ def __init__(self, data_dict):
72
+ self.data_dict = data_dict
73
+ self.frame1 = self.find_attribute('frame1')
74
+ self.frame2 = self.find_attribute('frame2')
75
+ self.flow = self.find_attribute('flow')
76
+ self.check_data()
77
+
78
+ def find_attribute(self, attribute):
79
+ try:
80
+ attribute_data = self.data[attribute]
81
+ except:
82
+ print("Given data directory has no key attribute \"{}\"".format(attribute))
83
+ return attribute_data
84
+
85
+ def check_data(self):
86
+ assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
87
+ assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
88
+ assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)
89
+
90
+ if len(self.frame1.shape)==2: self.frame1 = self.frame1.reshape(1, -1, 3)
91
+ if len(self.frame2.shape)==2: self.frame2 = self.frame2.reshape(1, -1, 3)
92
+ if len(self.flow.shape) == 2: self.flow = self.flow.reshape(1, -1, 3)
93
+
94
+ assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
95
+ assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"
96
+
97
+ def __len__(self):
98
+ return self.frame1.shape[0]
99
+
100
+ def __getitem__(self, index):
101
+ return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()
102
+
103
+
104
+ class UserData:
105
+ def __init__(self, application, data_dict):
106
+ self.application = application
107
+
108
+ if self.application == 'classification':
109
+ self.data_class = ClassificationData(data_dict)
110
+ elif self.application == 'registration':
111
+ self.data_class = RegistrationData(data_dict)
112
+ elif self.application == 'flow_estimation':
113
+ self.data_class = FlowData(data_dict)
114
+
115
+ def __len__(self):
116
+ return len(self.data_class)
117
+
118
+ def __getitem__(self, index):
119
+ return self.data_class[index]
thirdparty/learning3d/examples/test_curvenet.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import CurveNet
22
+ from learning3d.data_utils import ClassificationData, ModelNet40Data
23
+
24
+ def display_open3d(template):
25
+ template_ = o3d.geometry.PointCloud()
26
+ template_.points = o3d.utility.Vector3dVector(template)
27
+ # template_.paint_uniform_color([1, 0, 0])
28
+ o3d.visualization.draw_geometries([template_])
29
+
30
+ def test_one_epoch(device, model, test_loader, testset):
31
+ model.eval()
32
+ test_loss = 0.0
33
+ pred = 0.0
34
+ count = 0
35
+ for i, data in enumerate(tqdm(test_loader)):
36
+ points, target = data
37
+ target = target[:,0]
38
+
39
+ points = points.to(device)
40
+ target = target.to(device)
41
+
42
+ output = model(points)
43
+ loss_val = torch.nn.functional.nll_loss(
44
+ torch.nn.functional.log_softmax(output, dim=1), target, size_average=False)
45
+ print("Ground Truth Label: ", testset.get_shape(target[0].item()))
46
+ print("Predicted Label: ", testset.get_shape(torch.argmax(output[0]).item()))
47
+ display_open3d(points.detach().cpu().numpy()[0])
48
+
49
+ test_loss += loss_val.item()
50
+ count += output.size(0)
51
+
52
+ _, pred1 = output.max(dim=1)
53
+ ag = (pred1 == target)
54
+ am = ag.sum()
55
+ pred += am.item()
56
+
57
+ test_loss = float(test_loss)/count
58
+ accuracy = float(pred)/count
59
+ return test_loss, accuracy
60
+
61
+ def test(args, model, test_loader, testset):
62
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader, testset)
63
+ print("Accuracy: ", test_accuracy*100)
64
+
65
+ def options():
66
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
67
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
68
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
69
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
70
+
71
+ # settings for input data
72
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
73
+ metavar='DATASET', help='dataset type (default: modelnet)')
74
+ parser.add_argument('--num_points', default=1024, type=int,
75
+ metavar='N', help='points in point-cloud (default: 1024)')
76
+
77
+ # settings for CurveNet
78
+ parser.add_argument('-j', '--workers', default=4, type=int,
79
+ metavar='N', help='number of data loading workers (default: 4)')
80
+ parser.add_argument('-b', '--batch_size', default=32, type=int,
81
+ metavar='N', help='mini-batch size (default: 32)')
82
+ parser.add_argument('--num_classes', default=40, type=int,
83
+ metavar='K', help='number of classes to be predicted')
84
+
85
+ # settings for on training
86
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_curvenet/models/model.t7', type=str,
87
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
88
+ parser.add_argument('--device', default='cuda:0', type=str,
89
+ metavar='DEVICE', help='use CUDA if available')
90
+
91
+ args = parser.parse_args()
92
+ return args
93
+
94
+ def main():
95
+ args = options()
96
+ args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
97
+
98
+ testset = ClassificationData(ModelNet40Data(train=False))
99
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
100
+
101
+ if not torch.cuda.is_available():
102
+ args.device = 'cpu'
103
+ args.device = torch.device(args.device)
104
+
105
+ # Create PointNet Model.
106
+ model = CurveNet(num_classes=args.num_classes, k=20)
107
+
108
+ if args.pretrained:
109
+ assert os.path.isfile(args.pretrained)
110
+ weights = torch.load(args.pretrained, map_location='cpu')
111
+ weights = {k[7:]: v for k, v in weights.items()}
112
+ model.load_state_dict(weights)
113
+ model.to(args.device)
114
+
115
+ test(args, model, test_loader, testset)
116
+
117
+ if __name__ == '__main__':
118
+ main()
thirdparty/learning3d/examples/test_dcp.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import DGCNN, DCP
22
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
23
+
24
+ def get_transformations(igt):
25
+ R_ba = igt[:, 0:3, 0:3] # Ps = R_ba * Pt
26
+ translation_ba = igt[:, 0:3, 3].unsqueeze(2) # Ps = Pt + t_ba
27
+ R_ab = R_ba.permute(0, 2, 1) # Pt = R_ab * Ps
28
+ translation_ab = -torch.bmm(R_ab, translation_ba) # Pt = Ps + t_ab
29
+ return R_ab, translation_ab, R_ba, translation_ba
30
+
31
+ def display_open3d(template, source, transformed_source):
32
+ template_ = o3d.geometry.PointCloud()
33
+ source_ = o3d.geometry.PointCloud()
34
+ transformed_source_ = o3d.geometry.PointCloud()
35
+ template_.points = o3d.utility.Vector3dVector(template)
36
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
37
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
38
+ template_.paint_uniform_color([1, 0, 0])
39
+ source_.paint_uniform_color([0, 1, 0])
40
+ transformed_source_.paint_uniform_color([0, 0, 1])
41
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
42
+
43
+ def test_one_epoch(device, model, test_loader):
44
+ model.eval()
45
+ test_loss = 0.0
46
+ pred = 0.0
47
+ count = 0
48
+ for i, data in enumerate(tqdm(test_loader)):
49
+ template, source, igt = data
50
+ transformations = get_transformations(igt)
51
+ transformations = [t.to(device) for t in transformations]
52
+ R_ab, translation_ab, R_ba, translation_ba = transformations
53
+
54
+ template = template.to(device)
55
+ source = source.to(device)
56
+ igt = igt.to(device)
57
+
58
+ output = model(template, source)
59
+ display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
60
+
61
+ identity = torch.eye(3).cuda().unsqueeze(0).repeat(template.shape[0], 1, 1)
62
+ loss_val = torch.nn.functional.mse_loss(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity) \
63
+ + torch.nn.functional.mse_loss(output['est_t'], translation_ab[:,:,0])
64
+
65
+ cycle_loss = torch.nn.functional.mse_loss(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity) \
66
+ + torch.nn.functional.mse_loss(output['est_t_'], translation_ba[:,:,0])
67
+ loss_val = loss_val + cycle_loss * 0.1
68
+
69
+ test_loss += loss_val.item()
70
+ count += 1
71
+
72
+ test_loss = float(test_loss)/count
73
+ return test_loss
74
+
75
+ def test(args, model, test_loader):
76
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
77
+
78
+ def options():
79
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
80
+ parser.add_argument('--exp_name', type=str, default='exp_ipcrnet', metavar='N',
81
+ help='Name of the experiment')
82
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
83
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
84
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
85
+
86
+ # settings for input data
87
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
88
+ metavar='DATASET', help='dataset type (default: modelnet)')
89
+ parser.add_argument('--num_points', default=1024, type=int,
90
+ metavar='N', help='points in point-cloud (default: 1024)')
91
+
92
+ # settings for PointNet
93
+ parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
94
+ help='train pointnet (default: tune)')
95
+ parser.add_argument('--emb_dims', default=512, type=int,
96
+ metavar='K', help='dim. of the feature vector (default: 1024)')
97
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
98
+ help='symmetric function (default: max)')
99
+
100
+ # settings for on training
101
+ parser.add_argument('-j', '--workers', default=4, type=int,
102
+ metavar='N', help='number of data loading workers (default: 4)')
103
+ parser.add_argument('-b', '--batch_size', default=2, type=int,
104
+ metavar='N', help='mini-batch size (default: 32)')
105
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_dcp/models/best_model.t7', type=str,
106
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
107
+ parser.add_argument('--device', default='cuda:0', type=str,
108
+ metavar='DEVICE', help='use CUDA if available')
109
+
110
+ args = parser.parse_args()
111
+ return args
112
+
113
+ def main():
114
+ args = options()
115
+ torch.backends.cudnn.deterministic = True
116
+
117
+ trainset = RegistrationData('DCP', ModelNet40Data(train=True))
118
+ testset = RegistrationData('DCP', ModelNet40Data(train=False))
119
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
120
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
121
+
122
+ if not torch.cuda.is_available():
123
+ args.device = 'cpu'
124
+ args.device = torch.device(args.device)
125
+
126
+ # Create PointNet Model.
127
+ dgcnn = DGCNN(emb_dims=args.emb_dims)
128
+ model = DCP(feature_model=dgcnn, cycle=True)
129
+ model = model.to(args.device)
130
+
131
+ if args.pretrained:
132
+ assert os.path.isfile(args.pretrained)
133
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=False)
134
+ model.to(args.device)
135
+
136
+ test(args, model, test_loader)
137
+
138
+ if __name__ == '__main__':
139
+ main()
thirdparty/learning3d/examples/test_deepgmr.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import DeepGMR
22
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
23
+
24
+ def display_open3d(template, source, transformed_source):
25
+ template_ = o3d.geometry.PointCloud()
26
+ source_ = o3d.geometry.PointCloud()
27
+ transformed_source_ = o3d.geometry.PointCloud()
28
+ template_.points = o3d.utility.Vector3dVector(template)
29
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
30
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
31
+ template_.paint_uniform_color([1, 0, 0])
32
+ source_.paint_uniform_color([0, 1, 0])
33
+ transformed_source_.paint_uniform_color([0, 0, 1])
34
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
35
+
36
+ def rotation_error(R, R_gt):
37
+ cos_theta = (torch.einsum('bij,bij->b', R, R_gt) - 1) / 2
38
+ cos_theta = torch.clamp(cos_theta, -1, 1)
39
+ return torch.acos(cos_theta) * 180 / math.pi
40
+
41
+ def translation_error(t, t_gt):
42
+ return torch.norm(t - t_gt, dim=1)
43
+
44
+ def rmse(pts, T, T_gt):
45
+ pts_pred = pts @ T[:, :3, :3].transpose(1, 2) + T[:, :3, 3].unsqueeze(1)
46
+ pts_gt = pts @ T_gt[:, :3, :3].transpose(1, 2) + T_gt[:, :3, 3].unsqueeze(1)
47
+ return torch.norm(pts_pred - pts_gt, dim=2).mean(dim=1)
48
+
49
+ def test_one_epoch(device, model, test_loader):
50
+ model.eval()
51
+ test_loss = 0.0
52
+ pred = 0.0
53
+ count = 0
54
+ rotation_errors, translation_errors, rmses = [], [], []
55
+
56
+ for i, data in enumerate(tqdm(test_loader)):
57
+ template, source, igt = data
58
+
59
+ template = template.to(device)
60
+ source = source.to(device)
61
+ igt = igt.to(device)
62
+
63
+ output = model(template, source)
64
+ display_open3d(template.detach().cpu().numpy()[0, :, :3], source.detach().cpu().numpy()[0, :, :3], output['transformed_source'].detach().cpu().numpy()[0])
65
+
66
+ eye = torch.eye(4).expand_as(igt).to(igt.device)
67
+ mse1 = F.mse_loss(output['est_T_inverse'] @ torch.inverse(igt), eye)
68
+ mse2 = F.mse_loss(output['est_T'] @ igt, eye)
69
+ loss = mse1 + mse2
70
+
71
+ r_err = rotation_error(est_T_inverse[:, :3, :3], igt[:, :3, :3])
72
+ t_err = translation_error(est_T_inverse[:, :3, 3], igt[:, :3, 3])
73
+ rmse_val = rmse(template[:, :100], est_T_inverse, igt)
74
+ rotation_errors.append(r_err)
75
+ translation_errors.append(t_err)
76
+ rmses.append(rmse_val)
77
+
78
+ test_loss += loss_val.item()
79
+ count += 1
80
+
81
+ test_loss = float(test_loss)/count
82
+ print("Mean rotation error: {}, Mean translation error: {} and Mean RMSE: {}".format(np.mean(rotation_errors), np.mean(translation_errors), np.mean(rmses)))
83
+ return test_loss
84
+
85
+ def test(args, model, test_loader):
86
+ test_loss = test_one_epoch(args.device, model, test_loader)
87
+
88
+ def options():
89
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
90
+ parser.add_argument('--exp_name', type=str, default='exp_deepgmr', metavar='N',
91
+ help='Name of the experiment')
92
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
93
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
94
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
95
+
96
+ # settings for input data
97
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
98
+ metavar='DATASET', help='dataset type (default: modelnet)')
99
+ parser.add_argument('--num_points', default=1024, type=int,
100
+ metavar='N', help='points in point-cloud (default: 1024)')
101
+
102
+ parser.add_argument('--nearest_neighbors', default=20, type=int,
103
+ metavar='K', help='No of nearest neighbors to be estimated.')
104
+ parser.add_argument('--use_rri', default=True, type=bool,
105
+ help='Find nearest neighbors to estimate features from PointNet.')
106
+
107
+ # settings for on training
108
+ parser.add_argument('-j', '--workers', default=4, type=int,
109
+ metavar='N', help='number of data loading workers (default: 4)')
110
+ parser.add_argument('-b', '--batch_size', default=2, type=int,
111
+ metavar='N', help='mini-batch size (default: 32)')
112
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_deepgmr/models/best_model.pth', type=str,
113
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
114
+ parser.add_argument('--device', default='cuda:0', type=str,
115
+ metavar='DEVICE', help='use CUDA if available')
116
+
117
+ args = parser.parse_args()
118
+ return args
119
+
120
+ def main():
121
+ args = options()
122
+ torch.backends.cudnn.deterministic = True
123
+
124
+ trainset = RegistrationData('DeepGMR', ModelNet40Data(train=True))
125
+ testset = RegistrationData('DeepGMR', ModelNet40Data(train=False))
126
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
127
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
128
+
129
+ if not torch.cuda.is_available():
130
+ args.device = 'cpu'
131
+ args.device = torch.device(args.device)
132
+
133
+ model = DeepGMR(use_rri=args.use_rri, nearest_neighbors=args.nearest_neighbors)
134
+ model = model.to(args.device)
135
+
136
+ if args.pretrained:
137
+ assert os.path.isfile(args.pretrained)
138
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=False)
139
+ model.to(args.device)
140
+
141
+ test(args, model, test_loader)
142
+
143
+ if __name__ == '__main__':
144
+ main()
thirdparty/learning3d/examples/test_flownet.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ import open3d as o3d
6
+ import os
7
+ import gc
8
+ import argparse
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ from torch.optim.lr_scheduler import MultiStepLR
14
+ from learning3d.models import FlowNet3D
15
+ from learning3d.data_utils import SceneflowDataset
16
+ import numpy as np
17
+ from torch.utils.data import DataLoader
18
+ from tensorboardX import SummaryWriter
19
+ from tqdm import tqdm
20
+
21
+ def display_open3d(template, source, transformed_source):
22
+ template_ = o3d.geometry.PointCloud()
23
+ source_ = o3d.geometry.PointCloud()
24
+ transformed_source_ = o3d.geometry.PointCloud()
25
+ template_.points = o3d.utility.Vector3dVector(template)
26
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0.5,0.5]))
27
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
28
+ template_.paint_uniform_color([1, 0, 0])
29
+ source_.paint_uniform_color([0, 1, 0])
30
+ transformed_source_.paint_uniform_color([0, 0, 1])
31
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
32
+
33
+ def test_one_epoch(args, net, test_loader):
34
+ net.eval()
35
+
36
+ total_loss = 0
37
+ num_examples = 0
38
+ for i, data in enumerate(tqdm(test_loader)):
39
+ data = [d.to(args.device) for d in data]
40
+ pc1, pc2, color1, color2, flow, mask1 = data
41
+ pc1 = pc1.transpose(2,1).contiguous()
42
+ pc2 = pc2.transpose(2,1).contiguous()
43
+ color1 = color1.transpose(2,1).contiguous()
44
+ color2 = color2.transpose(2,1).contiguous()
45
+ flow = flow
46
+ mask1 = mask1.float()
47
+
48
+ batch_size = pc1.size(0)
49
+ num_examples += batch_size
50
+ flow_pred = net(pc1, pc2, color1, color2).permute(0,2,1)
51
+ loss_1 = torch.mean(mask1 * torch.sum((flow_pred - flow) * (flow_pred - flow), -1) / 2.0)
52
+
53
+ pc1, pc2 = pc1.permute(0,2,1), pc2.permute(0,2,1)
54
+ pc1_ = pc1 - flow_pred
55
+ print("Loss: ", loss_1)
56
+ display_open3d(pc1.detach().cpu().numpy()[0], pc2.detach().cpu().numpy()[0], pc1_.detach().cpu().numpy()[0])
57
+ total_loss += loss_1.item() * batch_size
58
+
59
+ return total_loss * 1.0 / num_examples
60
+
61
+
62
+ def test(args, net, test_loader):
63
+ test_loss = test_one_epoch(args, net, test_loader)
64
+
65
+ def main():
66
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
67
+ parser.add_argument('--model', type=str, default='flownet', metavar='N',
68
+ choices=['flownet'], help='Model to use, [flownet]')
69
+ parser.add_argument('--emb_dims', type=int, default=512, metavar='N',
70
+ help='Dimension of embeddings')
71
+ parser.add_argument('--num_points', type=int, default=2048,
72
+ help='Point Number [default: 2048]')
73
+ parser.add_argument('--test_batch_size', type=int, default=1, metavar='batch_size',
74
+ help='Size of batch)')
75
+
76
+ parser.add_argument('--gaussian_noise', type=bool, default=False, metavar='N',
77
+ help='Wheter to add gaussian noise')
78
+ parser.add_argument('--unseen', type=bool, default=False, metavar='N',
79
+ help='Whether to test on unseen category')
80
+ parser.add_argument('--dataset', type=str, default='SceneflowDataset',
81
+ choices=['SceneflowDataset'], metavar='N',
82
+ help='dataset to use')
83
+ parser.add_argument('--dataset_path', type=str, default='data_processed_maxcut_35_20k_2k_8192', metavar='N',
84
+ help='dataset to use')
85
+ parser.add_argument('--pretrained', type=str, default='learning3d/pretrained/exp_flownet/models/model.best.t7', metavar='N',
86
+ help='Pretrained model path')
87
+ parser.add_argument('--device', default='cuda:0', type=str,
88
+ metavar='DEVICE', help='use CUDA if available')
89
+
90
+ args = parser.parse_args()
91
+ if not torch.cuda.is_available():
92
+ args.device = torch.device('cpu')
93
+ else:
94
+ args.device = torch.device('cuda')
95
+
96
+ if args.dataset == 'SceneflowDataset':
97
+ test_loader = DataLoader(
98
+ SceneflowDataset(npoints=args.num_points, partition='test'),
99
+ batch_size=args.test_batch_size, shuffle=False, drop_last=False)
100
+ else:
101
+ raise Exception("not implemented")
102
+
103
+ net = FlowNet3D()
104
+ assert os.path.exists(args.pretrained), "Pretrained Model Doesn't Exists!"
105
+ net.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
106
+ net = net.to(args.device)
107
+
108
+ test(args, net, test_loader)
109
+ print('FINISH')
110
+
111
+
112
+ if __name__ == '__main__':
113
+ main()
thirdparty/learning3d/examples/test_masknet.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import MaskNet
22
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
23
+
24
+ def pc2open3d(data):
25
+ if torch.is_tensor(data): data = data.detach().cpu().numpy()
26
+ if len(data.shape) == 2:
27
+ pc = o3d.geometry.PointCloud()
28
+ pc.points = o3d.utility.Vector3dVector(data)
29
+ return pc
30
+ else:
31
+ print("Error in the shape of data given to Open3D!, Shape is ", data.shape)
32
+
33
+ def display_results(template, source, masked_template):
34
+ template = pc2open3d(template)
35
+ source = pc2open3d(source)
36
+ masked_template = pc2open3d(masked_template)
37
+
38
+ template.paint_uniform_color([1, 0, 0])
39
+ source.paint_uniform_color([0, 1, 0])
40
+ masked_template.paint_uniform_color([0, 0, 1])
41
+
42
+ o3d.visualization.draw_geometries([template, source])
43
+ o3d.visualization.draw_geometries([masked_template, source])
44
+
45
+ def evaluate_metrics(TP, FP, FN, TN, gt_mask):
46
+ # TP, FP, FN, TN: True +ve, False +ve, False -ve, True -ve
47
+ # gt_mask: Ground Truth mask [Nt, 1]
48
+
49
+ accuracy = (TP + TN)/gt_mask.shape[1]
50
+ misclassification_rate = (FN + FP)/gt_mask.shape[1]
51
+ # Precision: (What portion of positive identifications are actually correct?)
52
+ precision = TP / (TP + FP)
53
+ # Recall: (What portion of actual positives are identified correctly?)
54
+ recall = TP / (TP + FN)
55
+
56
+ fscore = (2*precision*recall) / (precision + recall)
57
+ return accuracy, precision, recall, fscore
58
+
59
+ # Function used to evaluate the predicted mask with ground truth mask.
60
+ def evaluate_mask(gt_mask, predicted_mask, predicted_mask_idx):
61
+ # gt_mask: Ground Truth Mask [Nt, 1]
62
+ # predicted_mask: Mask predicted by network [Nt, 1]
63
+ # predicted_mask_idx: Point indices chosen by network [Ns, 1]
64
+
65
+ if torch.is_tensor(gt_mask): gt_mask = gt_mask.detach().cpu().numpy()
66
+ if torch.is_tensor(gt_mask): predicted_mask = predicted_mask.detach().cpu().numpy()
67
+ if torch.is_tensor(predicted_mask_idx): predicted_mask_idx = predicted_mask_idx.detach().cpu().numpy()
68
+ gt_mask, predicted_mask, predicted_mask_idx = gt_mask.reshape(1,-1), predicted_mask.reshape(1,-1), predicted_mask_idx.reshape(1,-1)
69
+
70
+ gt_idx = np.where(gt_mask == 1)[1].reshape(1,-1) # Find indices of points which are actually in source.
71
+
72
+ # TP + FP = number of source points.
73
+ TP = np.intersect1d(predicted_mask_idx[0], gt_idx[0]).shape[0] # is inliner and predicted as inlier (True Positive) (Find common indices in predicted_mask_idx, gt_idx)
74
+ FP = len([x for x in predicted_mask_idx[0] if x not in gt_idx]) # isn't inlier but predicted as inlier (False Positive)
75
+ FN = FP # is inlier but predicted as outlier (False Negative) (due to binary classification)
76
+ TN = gt_mask.shape[1] - gt_idx.shape[1] - FN # is outlier and predicted as outlier (True Negative)
77
+ return evaluate_metrics(TP, FP, FN, TN, gt_mask)
78
+
79
+ def test_one_epoch(args, model, test_loader):
80
+ model.eval()
81
+ test_loss = 0.0
82
+ pred = 0.0
83
+ count = 0
84
+ precision_list = []
85
+
86
+ for i, data in enumerate(tqdm(test_loader)):
87
+ template, source, igt, gt_mask = data
88
+
89
+ template = template.to(args.device)
90
+ source = source.to(args.device)
91
+ igt = igt.to(args.device) # [source] = [igt]*[template]
92
+ gt_mask = gt_mask.to(args.device)
93
+
94
+ masked_template, predicted_mask = model(template, source)
95
+
96
+ # Evaluate mask based on classification metrics.
97
+ accuracy, precision, recall, fscore = evaluate_mask(gt_mask, predicted_mask, predicted_mask_idx = model.mask_idx)
98
+ precision_list.append(precision)
99
+
100
+ # Different ways to visualize results.
101
+ display_results(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], masked_template.detach().cpu().numpy()[0])
102
+
103
+ print("Mean Precision: ", np.mean(precision_list))
104
+
105
+ def test(args, model, test_loader):
106
+ test_one_epoch(args, model, test_loader)
107
+
108
+ def options():
109
+ parser = argparse.ArgumentParser(description='MaskNet: A Fully-Convolutional Network For Inlier Estimation (Testing)')
110
+
111
+ # settings for input data
112
+ parser.add_argument('--num_points', default=1024, type=int,
113
+ metavar='N', help='points in point-cloud (default: 1024)')
114
+ parser.add_argument('--partial_source', default=True, type=bool,
115
+ help='create partial source point cloud in dataset.')
116
+ parser.add_argument('--noise', default=False, type=bool,
117
+ help='Add noise in source point clouds.')
118
+ parser.add_argument('--outliers', default=False, type=bool,
119
+ help='Add outliers to template point cloud.')
120
+
121
+ # settings for on testing
122
+ parser.add_argument('-j', '--workers', default=1, type=int,
123
+ metavar='N', help='number of data loading workers (default: 4)')
124
+ parser.add_argument('-b', '--test_batch_size', default=1, type=int,
125
+ metavar='N', help='test-mini-batch size (default: 1)')
126
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_masknet/models/best_model.t7', type=str,
127
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
128
+ parser.add_argument('--device', default='cuda:0', type=str,
129
+ metavar='DEVICE', help='use CUDA if available')
130
+ parser.add_argument('--unseen', default=False, type=bool,
131
+ help='Use first 20 categories for training and last 20 for testing')
132
+
133
+ args = parser.parse_args()
134
+ return args
135
+
136
+ def main():
137
+ args = options()
138
+ torch.backends.cudnn.deterministic = True
139
+
140
+ testset = RegistrationData('PointNetLK', ModelNet40Data(train=False, num_points=args.num_points),
141
+ partial_source=args.partial_source, noise=args.noise,
142
+ additional_params={'use_masknet': True})
143
+ test_loader = DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
144
+
145
+ if not torch.cuda.is_available():
146
+ args.device = 'cpu'
147
+ args.device = torch.device(args.device)
148
+
149
+ # Load Pretrained MaskNet.
150
+ model = MaskNet()
151
+ if args.pretrained:
152
+ assert os.path.isfile(args.pretrained)
153
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
154
+ model = model.to(args.device)
155
+
156
+ test(args, model, test_loader)
157
+
158
+ if __name__ == '__main__':
159
+ main()
thirdparty/learning3d/examples/test_masknet2.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+
12
+ # Only if the files are in example folder.
13
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14
+ if BASE_DIR[-8:] == 'examples':
15
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
16
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
17
+
18
+ from learning3d.models import MaskNet2
19
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
20
+
21
+ def pc2open3d(data):
22
+ if torch.is_tensor(data): data = data.detach().cpu().numpy()
23
+ if len(data.shape) == 2:
24
+ pc = o3d.geometry.PointCloud()
25
+ pc.points = o3d.utility.Vector3dVector(data)
26
+ return pc
27
+ else:
28
+ print("Error in the shape of data given to Open3D!, Shape is ", data.shape)
29
+
30
+ def display_results(template, source, masked_template, masked_source):
31
+ template = pc2open3d(template)
32
+ source = pc2open3d(source)
33
+ masked_template = pc2open3d(masked_template)
34
+ masked_source = pc2open3d(masked_source)
35
+
36
+ template.paint_uniform_color([1, 0, 0])
37
+ source.paint_uniform_color([0, 1, 0])
38
+ # masked_template.paint_uniform_color([0, 0, 1])
39
+ masked_template.paint_uniform_color([1, 0, 0])
40
+ masked_source.paint_uniform_color([0, 1, 0])
41
+
42
+ o3d.visualization.draw_geometries([template, source])
43
+ o3d.visualization.draw_geometries([masked_template, masked_source])
44
+
45
+ def evaluate_metrics(TP, FP, FN, TN, gt_mask):
46
+ # TP, FP, FN, TN: True +ve, False +ve, False -ve, True -ve
47
+ # gt_mask: Ground Truth mask [Nt, 1]
48
+
49
+ accuracy = (TP + TN)/gt_mask.shape[1]
50
+ misclassification_rate = (FN + FP)/gt_mask.shape[1]
51
+ # Precision: (What portion of positive identifications are actually correct?)
52
+ precision = TP / (TP + FP)
53
+ # Recall: (What portion of actual positives are identified correctly?)
54
+ recall = TP / (TP + FN)
55
+
56
+ fscore = (2*precision*recall) / (precision + recall)
57
+ return accuracy, precision, recall, fscore
58
+
59
+ # Function used to evaluate the predicted mask with ground truth mask.
60
+ def evaluate_mask(gt_mask, predicted_mask, predicted_mask_idx):
61
+ # gt_mask: Ground Truth Mask [Nt, 1]
62
+ # predicted_mask: Mask predicted by network [Nt, 1]
63
+ # predicted_mask_idx: Point indices chosen by network [Ns, 1]
64
+
65
+ if torch.is_tensor(gt_mask): gt_mask = gt_mask.detach().cpu().numpy()
66
+ if torch.is_tensor(gt_mask): predicted_mask = predicted_mask.detach().cpu().numpy()
67
+ if torch.is_tensor(predicted_mask_idx): predicted_mask_idx = predicted_mask_idx.detach().cpu().numpy()
68
+ gt_mask, predicted_mask, predicted_mask_idx = gt_mask.reshape(1,-1), predicted_mask.reshape(1,-1), predicted_mask_idx.reshape(1,-1)
69
+
70
+ gt_idx = np.where(gt_mask == 1)[1].reshape(1,-1) # Find indices of points which are actually in source.
71
+
72
+ # TP + FP = number of source points.
73
+ TP = np.intersect1d(predicted_mask_idx[0], gt_idx[0]).shape[0] # is inliner and predicted as inlier (True Positive) (Find common indices in predicted_mask_idx, gt_idx)
74
+ FP = len([x for x in predicted_mask_idx[0] if x not in gt_idx]) # isn't inlier but predicted as inlier (False Positive)
75
+ FN = FP # is inlier but predicted as outlier (False Negative) (due to binary classification)
76
+ TN = gt_mask.shape[1] - gt_idx.shape[1] - FN # is outlier and predicted as outlier (True Negative)
77
+ return evaluate_metrics(TP, FP, FN, TN, gt_mask)
78
+
79
+ def test_one_epoch(args, model, test_loader):
80
+ model.eval()
81
+ test_loss = 0.0
82
+ pred = 0.0
83
+ count = 0
84
+
85
+ for i, data in enumerate(tqdm(test_loader)):
86
+ template, source, igt, gt_template_mask, gt_source_mask = data
87
+
88
+ template = template.to(args.device)
89
+ source = source.to(args.device)
90
+ igt = igt.to(args.device) # [source] = [igt]*[template]
91
+ gt_template_mask = gt_template_mask.to(args.device)
92
+ gt_source_mask = gt_source_mask.to(args.device)
93
+
94
+ masked_template, masked_source, template_mask, source_mask = model(template, source)
95
+
96
+ # TODO: Implement evaluation strategy.
97
+ '''
98
+ Evaluate mask based on classification metrics.
99
+ accuracy, precision, recall, fscore = evaluate_mask(gt_template_mask, template_mask, predicted_mask_idx = model.mask_idx)
100
+ precision_list.append(precision)
101
+ '''
102
+
103
+ # Different ways to visualize results.
104
+ display_results(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], masked_template.detach().cpu().numpy()[0], masked_source.detach().cpu().numpy()[0])
105
+
106
+ def test(args, model, test_loader):
107
+ test_one_epoch(args, model, test_loader)
108
+
109
+ def options():
110
+ parser = argparse.ArgumentParser(description='MaskNet: A Fully-Convolutional Network For Inlier Estimation (Testing)')
111
+
112
+ # settings for input data
113
+ parser.add_argument('--num_points', default=1024, type=int,
114
+ metavar='N', help='points in point-cloud (default: 1024)')
115
+ parser.add_argument('--partial_source', default=True, type=bool,
116
+ help='create partial source point cloud in dataset.')
117
+ parser.add_argument('--partial_template', default=True, type=bool,
118
+ help='create partial source point cloud in dataset.')
119
+ parser.add_argument('--noise', default=False, type=bool,
120
+ help='Add noise in source point clouds.')
121
+ parser.add_argument('--outliers', default=False, type=bool,
122
+ help='Add outliers to template point cloud.')
123
+
124
+ # settings for on testing
125
+ parser.add_argument('-j', '--workers', default=1, type=int,
126
+ metavar='N', help='number of data loading workers (default: 4)')
127
+ parser.add_argument('-b', '--test_batch_size', default=1, type=int,
128
+ metavar='N', help='test-mini-batch size (default: 1)')
129
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_masknet2/models/best_model_0.7.t7', type=str,
130
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
131
+ parser.add_argument('--device', default='cuda:0', type=str,
132
+ metavar='DEVICE', help='use CUDA if available')
133
+ parser.add_argument('--unseen', default=False, type=bool,
134
+ help='Use first 20 categories for training and last 20 for testing')
135
+
136
+ args = parser.parse_args()
137
+ return args
138
+
139
+ def main():
140
+ args = options()
141
+ torch.backends.cudnn.deterministic = True
142
+
143
+ testset = RegistrationData('PointNetLK', ModelNet40Data(train=False, num_points=args.num_points),
144
+ partial_template=args.partial_template, partial_source=args.partial_source,
145
+ noise=args.noise, additional_params={'use_masknet': True, 'partial_point_cloud_method': 'planar_crop'})
146
+ test_loader = DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
147
+
148
+ if not torch.cuda.is_available():
149
+ args.device = 'cpu'
150
+ args.device = torch.device(args.device)
151
+
152
+ # Load Pretrained MaskNet.
153
+ model = MaskNet2()
154
+ if args.pretrained:
155
+ assert os.path.isfile(args.pretrained)
156
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
157
+ model = model.to(args.device)
158
+
159
+ test(args, model, test_loader)
160
+
161
+ if __name__ == '__main__':
162
+ main()
thirdparty/learning3d/examples/test_pcn.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # author: Vinit Sarode (vinitsarode5@gmail.com) 03/23/2020
2
+
3
+ import open3d as o3d
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import logging
8
+ import numpy
9
+ import numpy as np
10
+ import torch
11
+ import torch.utils.data
12
+ import torchvision
13
+ from torch.utils.data import DataLoader
14
+ from tensorboardX import SummaryWriter
15
+ from tqdm import tqdm
16
+
17
+ # Only if the files are in example folder.
18
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
+ if BASE_DIR[-8:] == 'examples':
20
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
21
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
22
+
23
+ from learning3d.models import PCN
24
+ from learning3d.data_utils import ModelNet40Data, ClassificationData
25
+ from learning3d.losses import ChamferDistanceLoss
26
+
27
+ def display_open3d(input_pc, output):
28
+ input_pc_ = o3d.geometry.PointCloud()
29
+ output_ = o3d.geometry.PointCloud()
30
+ input_pc_.points = o3d.utility.Vector3dVector(input_pc)
31
+ output_.points = o3d.utility.Vector3dVector(output + np.array([1,0,0]))
32
+ input_pc_.paint_uniform_color([1, 0, 0])
33
+ output_.paint_uniform_color([0, 1, 0])
34
+ o3d.visualization.draw_geometries([input_pc_, output_])
35
+
36
+ def test_one_epoch(device, model, test_loader):
37
+ model.eval()
38
+ test_loss = 0.0
39
+ pred = 0.0
40
+ count = 0
41
+ for i, data in enumerate(tqdm(test_loader)):
42
+ points, _ = data
43
+
44
+ points = points.to(device)
45
+
46
+ output = model(points)
47
+ loss_val = ChamferDistanceLoss()(points, output['coarse_output'])
48
+ print("Loss Val: ", loss_val)
49
+ display_open3d(points[0].detach().cpu().numpy(), output['coarse_output'][0].detach().cpu().numpy())
50
+
51
+ test_loss += loss_val.item()
52
+ count += 1
53
+
54
+ test_loss = float(test_loss)/count
55
+ return test_loss
56
+
57
+ def test(args, model, test_loader):
58
+ test_loss = test_one_epoch(args.device, model, test_loader)
59
+
60
+ def options():
61
+ parser = argparse.ArgumentParser(description='Point Completion Network')
62
+ parser.add_argument('--exp_name', type=str, default='exp_pcn', metavar='N',
63
+ help='Name of the experiment')
64
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
65
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
66
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
67
+
68
+ # settings for input data
69
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
70
+ metavar='DATASET', help='dataset type (default: modelnet)')
71
+ parser.add_argument('--num_points', default=1024, type=int,
72
+ metavar='N', help='points in point-cloud (default: 1024)')
73
+
74
+ # settings for PCN
75
+ parser.add_argument('--emb_dims', default=1024, type=int,
76
+ metavar='K', help='dim. of the feature vector (default: 1024)')
77
+ parser.add_argument('--detailed_output', default=False, type=bool,
78
+ help='Coarse + Fine Output')
79
+
80
+ # settings for on training
81
+ parser.add_argument('--seed', type=int, default=1234)
82
+ parser.add_argument('-j', '--workers', default=4, type=int,
83
+ metavar='N', help='number of data loading workers (default: 4)')
84
+ parser.add_argument('-b', '--batch_size', default=32, type=int,
85
+ metavar='N', help='mini-batch size (default: 32)')
86
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_pcn/models/best_model.t7', type=str,
87
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
88
+ parser.add_argument('--device', default='cuda:0', type=str,
89
+ metavar='DEVICE', help='use CUDA if available')
90
+
91
+ args = parser.parse_args()
92
+ return args
93
+
94
+ def main():
95
+ args = options()
96
+ args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
97
+
98
+ trainset = ClassificationData(ModelNet40Data(train=True))
99
+ testset = ClassificationData(ModelNet40Data(train=False))
100
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
101
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
102
+
103
+ if not torch.cuda.is_available():
104
+ args.device = 'cpu'
105
+ args.device = torch.device(args.device)
106
+
107
+ # Create PointNet Model.
108
+ model = PCN(emb_dims=args.emb_dims, detailed_output=args.detailed_output)
109
+
110
+ if args.pretrained:
111
+ assert os.path.isfile(args.pretrained)
112
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
113
+ model.to(args.device)
114
+
115
+ test(args, model, test_loader)
116
+
117
+ if __name__ == '__main__':
118
+ main()
thirdparty/learning3d/examples/test_pcrnet.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import PointNet, iPCRNet
22
+ from learning3d.losses import ChamferDistanceLoss
23
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
24
+
25
+
26
+ def display_open3d(template, source, transformed_source):
27
+ template_ = o3d.geometry.PointCloud()
28
+ source_ = o3d.geometry.PointCloud()
29
+ transformed_source_ = o3d.geometry.PointCloud()
30
+ template_.points = o3d.utility.Vector3dVector(template)
31
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
32
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
33
+ template_.paint_uniform_color([1, 0, 0])
34
+ source_.paint_uniform_color([0, 1, 0])
35
+ transformed_source_.paint_uniform_color([0, 0, 1])
36
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
37
+
38
+ def test_one_epoch(device, model, test_loader):
39
+ model.eval()
40
+ test_loss = 0.0
41
+ pred = 0.0
42
+ count = 0
43
+ for i, data in enumerate(tqdm(test_loader)):
44
+ template, source, igt = data
45
+
46
+ template = template.to(device)
47
+ source = source.to(device)
48
+ igt = igt.to(device)
49
+
50
+ output = model(template, source)
51
+ display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
52
+ loss_val = ChamferDistanceLoss()(template, output['transformed_source'])
53
+
54
+ test_loss += loss_val.item()
55
+ count += 1
56
+
57
+ test_loss = float(test_loss)/count
58
+ return test_loss
59
+
60
+ def test(args, model, test_loader):
61
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
62
+
63
+
64
+ def options():
65
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
66
+ parser.add_argument('--exp_name', type=str, default='exp_ipcrnet', metavar='N',
67
+ help='Name of the experiment')
68
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
69
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
70
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
71
+
72
+ # settings for input data
73
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
74
+ metavar='DATASET', help='dataset type (default: modelnet)')
75
+ parser.add_argument('--num_points', default=1024, type=int,
76
+ metavar='N', help='points in point-cloud (default: 1024)')
77
+
78
+ # settings for PointNet
79
+ parser.add_argument('--emb_dims', default=1024, type=int,
80
+ metavar='K', help='dim. of the feature vector (default: 1024)')
81
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
82
+ help='symmetric function (default: max)')
83
+
84
+ # settings for on training
85
+ parser.add_argument('-j', '--workers', default=4, type=int,
86
+ metavar='N', help='number of data loading workers (default: 4)')
87
+ parser.add_argument('-b', '--batch_size', default=20, type=int,
88
+ metavar='N', help='mini-batch size (default: 32)')
89
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_ipcrnet/models/best_model.t7', type=str,
90
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
91
+ parser.add_argument('--device', default='cuda:0', type=str,
92
+ metavar='DEVICE', help='use CUDA if available')
93
+
94
+ args = parser.parse_args()
95
+ return args
96
+
97
+ def main():
98
+ args = options()
99
+
100
+ testset = RegistrationData('PCRNet', ModelNet40Data(train=False))
101
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
102
+
103
+ if not torch.cuda.is_available():
104
+ args.device = 'cpu'
105
+ args.device = torch.device(args.device)
106
+
107
+ # Create PointNet Model.
108
+ ptnet = PointNet(emb_dims=args.emb_dims)
109
+ model = iPCRNet(feature_model=ptnet)
110
+ model = model.to(args.device)
111
+
112
+ if args.pretrained:
113
+ assert os.path.isfile(args.pretrained)
114
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
115
+ model.to(args.device)
116
+
117
+ test(args, model, test_loader)
118
+
119
+ if __name__ == '__main__':
120
+ main()
thirdparty/learning3d/examples/test_pnlk.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import PointNet, PointNetLK
22
+ from learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
23
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
24
+
25
+ def display_open3d(template, source, transformed_source):
26
+ template_ = o3d.geometry.PointCloud()
27
+ source_ = o3d.geometry.PointCloud()
28
+ transformed_source_ = o3d.geometry.PointCloud()
29
+ template_.points = o3d.utility.Vector3dVector(template)
30
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
31
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
32
+ template_.paint_uniform_color([1, 0, 0])
33
+ source_.paint_uniform_color([0, 1, 0])
34
+ transformed_source_.paint_uniform_color([0, 0, 1])
35
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
36
+
37
+ def test_one_epoch(device, model, test_loader):
38
+ model.eval()
39
+ test_loss = 0.0
40
+ pred = 0.0
41
+ count = 0
42
+ for i, data in enumerate(tqdm(test_loader)):
43
+ template, source, igt = data
44
+
45
+ template = template.to(device)
46
+ source = source.to(device)
47
+ igt = igt.to(device)
48
+
49
+ output = model(template, source)
50
+
51
+ display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
52
+ loss_val = FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])
53
+
54
+ test_loss += loss_val.item()
55
+ count += 1
56
+
57
+ test_loss = float(test_loss)/count
58
+ return test_loss
59
+
60
+ def test(args, model, test_loader):
61
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
62
+
63
+
64
+ def options():
65
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
66
+ parser.add_argument('--exp_name', type=str, default='exp_pnlk_v1', metavar='N',
67
+ help='Name of the experiment')
68
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
69
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
70
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
71
+
72
+ # settings for input data
73
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
74
+ metavar='DATASET', help='dataset type (default: modelnet)')
75
+ parser.add_argument('--num_points', default=1024, type=int,
76
+ metavar='N', help='points in point-cloud (default: 1024)')
77
+
78
+ # settings for PointNet
79
+ parser.add_argument('--emb_dims', default=1024, type=int,
80
+ metavar='K', help='dim. of the feature vector (default: 1024)')
81
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
82
+ help='symmetric function (default: max)')
83
+
84
+ # settings for on training
85
+ parser.add_argument('--seed', type=int, default=1234)
86
+ parser.add_argument('-j', '--workers', default=4, type=int,
87
+ metavar='N', help='number of data loading workers (default: 4)')
88
+ parser.add_argument('-b', '--batch_size', default=10, type=int,
89
+ metavar='N', help='mini-batch size (default: 32)')
90
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_pnlk/models/best_model.t7', type=str,
91
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
92
+ parser.add_argument('--device', default='cuda:0', type=str,
93
+ metavar='DEVICE', help='use CUDA if available')
94
+
95
+ args = parser.parse_args()
96
+ return args
97
+
98
+ def main():
99
+ args = options()
100
+
101
+ testset = RegistrationData('PointNetLK', ModelNet40Data(train=False))
102
+ test_loader = DataLoader(testset, batch_size=8, shuffle=False, drop_last=False, num_workers=args.workers)
103
+
104
+ if not torch.cuda.is_available():
105
+ args.device = 'cpu'
106
+ args.device = torch.device(args.device)
107
+
108
+ # Create PointNet Model.
109
+ ptnet = PointNet(emb_dims=args.emb_dims, use_bn=True)
110
+ model = PointNetLK(feature_model=ptnet)
111
+ model = model.to(args.device)
112
+
113
+ if args.pretrained:
114
+ assert os.path.isfile(args.pretrained)
115
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
116
+ model.to(args.device)
117
+
118
+ test(args, model, test_loader)
119
+
120
+ if __name__ == '__main__':
121
+ main()
thirdparty/learning3d/examples/test_pointconv.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import create_pointconv
22
+ from learning3d.models import Classifier
23
+ from learning3d.data_utils import ClassificationData, ModelNet40Data
24
+
25
+ def display_open3d(template):
26
+ template_ = o3d.geometry.PointCloud()
27
+ template_.points = o3d.utility.Vector3dVector(template)
28
+ # template_.paint_uniform_color([1, 0, 0])
29
+ o3d.visualization.draw_geometries([template_])
30
+
31
+ def test_one_epoch(device, model, test_loader, testset):
32
+ model.eval()
33
+ test_loss = 0.0
34
+ pred = 0.0
35
+ count = 0
36
+ for i, data in enumerate(tqdm(test_loader)):
37
+ points, target = data
38
+ target = target[:,0]
39
+
40
+ points = points.to(device)
41
+ target = target.to(device)
42
+
43
+ output = model(points)
44
+ loss_val = torch.nn.functional.nll_loss(
45
+ torch.nn.functional.log_softmax(output, dim=1), target, size_average=False)
46
+ print("Ground Truth Label: ", testset.get_shape(target[0].item()))
47
+ print("Predicted Label: ", testset.get_shape(torch.argmax(output[0]).item()))
48
+ display_open3d(points.detach().cpu().numpy()[0])
49
+
50
+ test_loss += loss_val.item()
51
+ count += output.size(0)
52
+
53
+ _, pred1 = output.max(dim=1)
54
+ ag = (pred1 == target)
55
+ am = ag.sum()
56
+ pred += am.item()
57
+
58
+ test_loss = float(test_loss)/count
59
+ accuracy = float(pred)/count
60
+ return test_loss, accuracy
61
+
62
+ def test(args, model, test_loader, testset):
63
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader, testset)
64
+
65
+ def options():
66
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
67
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
68
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
69
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
70
+
71
+ # settings for input data
72
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
73
+ metavar='DATASET', help='dataset type (default: modelnet)')
74
+ parser.add_argument('--num_points', default=1024, type=int,
75
+ metavar='N', help='points in point-cloud (default: 1024)')
76
+
77
+ # settings for PointNet
78
+ parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
79
+ help='train pointnet (default: tune)')
80
+ parser.add_argument('-j', '--workers', default=4, type=int,
81
+ metavar='N', help='number of data loading workers (default: 4)')
82
+ parser.add_argument('-b', '--batch_size', default=32, type=int,
83
+ metavar='N', help='mini-batch size (default: 32)')
84
+ parser.add_argument('--emb_dims', default=1024, type=int,
85
+ metavar='K', help='dim. of the feature vector (default: 1024)')
86
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
87
+ help='symmetric function (default: max)')
88
+
89
+ # settings for on training
90
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_classifier/models/best_model.t7', type=str,
91
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
92
+ parser.add_argument('--device', default='cuda:0', type=str,
93
+ metavar='DEVICE', help='use CUDA if available')
94
+
95
+ args = parser.parse_args()
96
+ return args
97
+
98
+ def main():
99
+ args = options()
100
+ args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
101
+
102
+ testset = ClassificationData(ModelNet40Data(train=False))
103
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
104
+
105
+ if not torch.cuda.is_available():
106
+ args.device = 'cpu'
107
+ args.device = torch.device(args.device)
108
+
109
+ # To use pretrained model provided by authors.
110
+ # PointConv = create_pointconv(classifier=True, pretrained='path of pretrained model.')
111
+ # model = PointConv(emb_dims=args.emb_dims, classifier=True, pretrained='path of pretrained model.')
112
+
113
+ # To use your own pretrained model.
114
+ PointConv = create_pointconv(classifier=False, pretrained=None)
115
+ ptconv = PointConv(emb_dims=args.emb_dims, classifier=True, pretrained=None)
116
+ model = Classifier(feature_model=ptconv)
117
+
118
+ if args.pretrained:
119
+ assert os.path.isfile(args.pretrained)
120
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
121
+ model.to(args.device)
122
+
123
+ test(args, model, test_loader, testset)
124
+
125
+ if __name__ == '__main__':
126
+ main()
thirdparty/learning3d/examples/test_pointnet.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import PointNet
22
+ from learning3d.models import Classifier
23
+ from learning3d.data_utils import ClassificationData, ModelNet40Data
24
+
25
+ def display_open3d(template):
26
+ template_ = o3d.geometry.PointCloud()
27
+ template_.points = o3d.utility.Vector3dVector(template)
28
+ # template_.paint_uniform_color([1, 0, 0])
29
+ o3d.visualization.draw_geometries([template_])
30
+
31
+ def test_one_epoch(device, model, test_loader, testset):
32
+ model.eval()
33
+ test_loss = 0.0
34
+ pred = 0.0
35
+ count = 0
36
+ for i, data in enumerate(tqdm(test_loader)):
37
+ points, target = data
38
+ target = target[:,0]
39
+
40
+ points = points.to(device)
41
+ target = target.to(device)
42
+
43
+ output = model(points)
44
+ loss_val = torch.nn.functional.nll_loss(
45
+ torch.nn.functional.log_softmax(output, dim=1), target, size_average=False)
46
+ print("Ground Truth Label: ", testset.get_shape(target[0].item()))
47
+ print("Predicted Label: ", testset.get_shape(torch.argmax(output[0]).item()))
48
+ display_open3d(points.detach().cpu().numpy()[0])
49
+
50
+ test_loss += loss_val.item()
51
+ count += output.size(0)
52
+
53
+ _, pred1 = output.max(dim=1)
54
+ ag = (pred1 == target)
55
+ am = ag.sum()
56
+ pred += am.item()
57
+
58
+ test_loss = float(test_loss)/count
59
+ accuracy = float(pred)/count
60
+ return test_loss, accuracy
61
+
62
+ def test(args, model, test_loader, testset):
63
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader, testset)
64
+
65
+ def options():
66
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
67
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
68
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
69
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
70
+
71
+ # settings for input data
72
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
73
+ metavar='DATASET', help='dataset type (default: modelnet)')
74
+ parser.add_argument('--num_points', default=1024, type=int,
75
+ metavar='N', help='points in point-cloud (default: 1024)')
76
+
77
+ # settings for PointNet
78
+ parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
79
+ help='train pointnet (default: tune)')
80
+ parser.add_argument('-j', '--workers', default=4, type=int,
81
+ metavar='N', help='number of data loading workers (default: 4)')
82
+ parser.add_argument('-b', '--batch_size', default=32, type=int,
83
+ metavar='N', help='mini-batch size (default: 32)')
84
+ parser.add_argument('--emb_dims', default=1024, type=int,
85
+ metavar='K', help='dim. of the feature vector (default: 1024)')
86
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
87
+ help='symmetric function (default: max)')
88
+
89
+ # settings for on training
90
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_classifier/models/best_model.t7', type=str,
91
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
92
+ parser.add_argument('--device', default='cuda:0', type=str,
93
+ metavar='DEVICE', help='use CUDA if available')
94
+
95
+ args = parser.parse_args()
96
+ return args
97
+
98
+ def main():
99
+ args = options()
100
+ args.dataset_path = os.path.join(os.getcwd(), os.pardir, os.pardir, 'ModelNet40', 'ModelNet40')
101
+
102
+ testset = ClassificationData(ModelNet40Data(train=False))
103
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
104
+
105
+ if not torch.cuda.is_available():
106
+ args.device = 'cpu'
107
+ args.device = torch.device(args.device)
108
+
109
+ # Create PointNet Model.
110
+ ptnet = PointNet(emb_dims=args.emb_dims, use_bn=True)
111
+ model = Classifier(feature_model=ptnet)
112
+
113
+ if args.pretrained:
114
+ assert os.path.isfile(args.pretrained)
115
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
116
+ model.to(args.device)
117
+
118
+ test(args, model, test_loader, testset)
119
+
120
+ if __name__ == '__main__':
121
+ main()
thirdparty/learning3d/examples/test_prnet.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import PRNet
22
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
23
+
24
+ def get_transformations(igt):
25
+ R_ba = igt[:, 0:3, 0:3] # Ps = R_ba * Pt
26
+ translation_ba = igt[:, 0:3, 3].unsqueeze(2) # Ps = Pt + t_ba
27
+ R_ab = R_ba.permute(0, 2, 1) # Pt = R_ab * Ps
28
+ translation_ab = -torch.bmm(R_ab, translation_ba) # Pt = Ps + t_ab
29
+ return R_ab, translation_ab, R_ba, translation_ba
30
+
31
+ def display_open3d(template, source, transformed_source):
32
+ template_ = o3d.geometry.PointCloud()
33
+ source_ = o3d.geometry.PointCloud()
34
+ transformed_source_ = o3d.geometry.PointCloud()
35
+ template_.points = o3d.utility.Vector3dVector(template)
36
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
37
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
38
+ template_.paint_uniform_color([1, 0, 0])
39
+ source_.paint_uniform_color([0, 1, 0])
40
+ transformed_source_.paint_uniform_color([0, 0, 1])
41
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
42
+
43
+ def test_one_epoch(device, model, test_loader):
44
+ model.eval()
45
+ test_loss = 0.0
46
+ pred = 0.0
47
+ count = 0
48
+ for i, data in enumerate(tqdm(test_loader)):
49
+ template, source, igt = data
50
+
51
+ transformations = get_transformations(igt)
52
+ transformations = [t.to(device) for t in transformations]
53
+ R_ab, translation_ab, R_ba, translation_ba = transformations
54
+
55
+ template = template.to(device)
56
+ source = source.to(device)
57
+ igt = igt.to(device)
58
+
59
+ output = model(template, source, R_ab, translation_ab.squeeze(2))
60
+ display_open3d(template.detach().cpu().numpy()[0], source.detach().cpu().numpy()[0], output['transformed_source'].detach().cpu().numpy()[0])
61
+
62
+ test_loss += output['loss'].item()
63
+ count += 1
64
+
65
+ test_loss = float(test_loss)/count
66
+ return test_loss
67
+
68
+ def test(args, model, test_loader):
69
+ test_loss = test_one_epoch(args.device, model, test_loader)
70
+
71
+ def options():
72
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
73
+ parser.add_argument('--exp_name', type=str, default='exp_prnet', metavar='N',
74
+ help='Name of the experiment')
75
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
76
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
77
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
78
+
79
+ # settings for input data
80
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
81
+ metavar='DATASET', help='dataset type (default: modelnet)')
82
+
83
+ # settings for PointNet
84
+ parser.add_argument('--emb_dims', default=512, type=int,
85
+ metavar='K', help='dim. of the feature vector (default: 1024)')
86
+ parser.add_argument('--num_iterations', default=3, type=int,
87
+ help='Number of Iterations')
88
+
89
+ parser.add_argument('-j', '--workers', default=4, type=int,
90
+ metavar='N', help='number of data loading workers (default: 4)')
91
+ parser.add_argument('-b', '--batch_size', default=1, type=int,
92
+ metavar='N', help='mini-batch size (default: 32)')
93
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_prnet/models/best_model.t7', type=str,
94
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
95
+ parser.add_argument('--device', default='cuda:0', type=str,
96
+ metavar='DEVICE', help='use CUDA if available')
97
+
98
+ args = parser.parse_args()
99
+ return args
100
+
101
+ def main():
102
+ args = options()
103
+ torch.backends.cudnn.deterministic = True
104
+
105
+ trainset = RegistrationData('PRNet', ModelNet40Data(train=True), partial_source=True, partial_template=True)
106
+ testset = RegistrationData('PRNet', ModelNet40Data(train=False), partial_source=True, partial_template=True)
107
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
108
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
109
+
110
+ if not torch.cuda.is_available():
111
+ args.device = 'cpu'
112
+ args.device = torch.device(args.device)
113
+
114
+ # Create PointNet Model.
115
+ model = PRNet(emb_dims=args.emb_dims, num_iters=args.num_iterations)
116
+ model = model.to(args.device)
117
+
118
+ if args.pretrained:
119
+ assert os.path.isfile(args.pretrained)
120
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=False)
121
+ model.to(args.device)
122
+
123
+ test(args, model, test_loader)
124
+
125
+ if __name__ == '__main__':
126
+ main()
thirdparty/learning3d/examples/test_rpmnet.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import RPMNet, PPFNet
22
+ from learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
23
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
24
+
25
+ def display_open3d(template, source, transformed_source):
26
+ template_ = o3d.geometry.PointCloud()
27
+ source_ = o3d.geometry.PointCloud()
28
+ transformed_source_ = o3d.geometry.PointCloud()
29
+ template_.points = o3d.utility.Vector3dVector(template)
30
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
31
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
32
+ template_.paint_uniform_color([1, 0, 0])
33
+ source_.paint_uniform_color([0, 1, 0])
34
+ transformed_source_.paint_uniform_color([0, 0, 1])
35
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
36
+
37
+ def test_one_epoch(device, model, test_loader):
38
+ model.eval()
39
+ test_loss = 0.0
40
+ pred = 0.0
41
+ count = 0
42
+ for i, data in enumerate(tqdm(test_loader)):
43
+ template, source, igt = data
44
+
45
+ template = template.to(device)
46
+ source = source.to(device)
47
+ igt = igt.to(device)
48
+
49
+ output = model(template, source)
50
+
51
+ display_open3d(template.detach().cpu().numpy()[0,:,:3], source.detach().cpu().numpy()[0,:,:3], output['transformed_source'].detach().cpu().numpy()[0])
52
+ loss_val = FrobeniusNormLoss()(output['est_T'], igt)
53
+
54
+ test_loss += loss_val.item()
55
+ count += 1
56
+
57
+ test_loss = float(test_loss)/count
58
+ return test_loss
59
+
60
+ def test(args, model, test_loader):
61
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
62
+
63
+
64
+ def options():
65
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
66
+ parser.add_argument('--exp_name', type=str, default='exp_rpmnet', metavar='N',
67
+ help='Name of the experiment')
68
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
69
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
70
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
71
+
72
+ # settings for input data
73
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
74
+ metavar='DATASET', help='dataset type (default: modelnet)')
75
+ parser.add_argument('--num_points', default=1024, type=int,
76
+ metavar='N', help='points in point-cloud (default: 1024)')
77
+
78
+ # settings for PointNet
79
+ parser.add_argument('--emb_dims', default=1024, type=int,
80
+ metavar='K', help='dim. of the feature vector (default: 1024)')
81
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
82
+ help='symmetric function (default: max)')
83
+
84
+ # settings for on training
85
+ parser.add_argument('--seed', type=int, default=1234)
86
+ parser.add_argument('-j', '--workers', default=4, type=int,
87
+ metavar='N', help='number of data loading workers (default: 4)')
88
+ parser.add_argument('-b', '--batch_size', default=10, type=int,
89
+ metavar='N', help='mini-batch size (default: 32)')
90
+ parser.add_argument('--pretrained', default='learning3d/pretrained/exp_rpmnet/models/partial-trained.pth', type=str,
91
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
92
+ parser.add_argument('--device', default='cuda:0', type=str,
93
+ metavar='DEVICE', help='use CUDA if available')
94
+
95
+ args = parser.parse_args()
96
+ return args
97
+
98
+ def main():
99
+ args = options()
100
+
101
+ testset = RegistrationData('RPMNet', ModelNet40Data(train=False, num_points=args.num_points, use_normals=True), partial_source=True, partial_template=False)
102
+ test_loader = DataLoader(testset, batch_size=1, shuffle=False, drop_last=False, num_workers=args.workers)
103
+
104
+ if not torch.cuda.is_available():
105
+ args.device = 'cpu'
106
+ args.device = torch.device(args.device)
107
+
108
+ # Create RPMNet Model.
109
+ model = RPMNet(feature_model=PPFNet())
110
+ model = model.to(args.device)
111
+
112
+ if args.pretrained:
113
+ assert os.path.isfile(args.pretrained)
114
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu')['state_dict'])
115
+ model.to(args.device)
116
+
117
+ test(args, model, test_loader)
118
+
119
+ if __name__ == '__main__':
120
+ main()
thirdparty/learning3d/examples/train_PointNetLK.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import logging
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ from torch.utils.data import DataLoader
11
+ from tensorboardX import SummaryWriter
12
+ from tqdm import tqdm
13
+
14
+ # Only if the files are in example folder.
15
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
+ if BASE_DIR[-8:] == 'examples':
17
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
18
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
19
+
20
+ from learning3d.models import PointNet
21
+ from learning3d.models import PointNetLK
22
+ from learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
23
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
24
+
25
+ def _init_(args):
26
+ if not os.path.exists('checkpoints'):
27
+ os.makedirs('checkpoints')
28
+ if not os.path.exists('checkpoints/' + args.exp_name):
29
+ os.makedirs('checkpoints/' + args.exp_name)
30
+ if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
31
+ os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
32
+ os.system('cp main.py checkpoints' + '/' + args.exp_name + '/' + 'main.py.backup')
33
+ os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
34
+
35
+
36
+ class IOStream:
37
+ def __init__(self, path):
38
+ self.f = open(path, 'a')
39
+
40
+ def cprint(self, text):
41
+ print(text)
42
+ self.f.write(text + '\n')
43
+ self.f.flush()
44
+
45
+ def close(self):
46
+ self.f.close()
47
+
48
+ def test_one_epoch(device, model, test_loader):
49
+ model.eval()
50
+ test_loss = 0.0
51
+ pred = 0.0
52
+ count = 0
53
+ for i, data in enumerate(tqdm(test_loader)):
54
+ template, source, igt = data
55
+
56
+ template = template.to(device)
57
+ source = source.to(device)
58
+ igt = igt.to(device)
59
+
60
+ output = model(template, source)
61
+ loss_val = FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])
62
+
63
+ test_loss += loss_val.item()
64
+ count += 1
65
+
66
+ test_loss = float(test_loss)/count
67
+ return test_loss
68
+
69
+ def test(args, model, test_loader, textio):
70
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
71
+ textio.cprint('Validation Loss: %f & Validation Accuracy: %f'%(test_loss, test_accuracy))
72
+
73
+ def train_one_epoch(device, model, train_loader, optimizer):
74
+ model.train()
75
+ train_loss = 0.0
76
+ pred = 0.0
77
+ count = 0
78
+ for i, data in enumerate(tqdm(train_loader)):
79
+ template, source, igt = data
80
+
81
+ template = template.to(device)
82
+ source = source.to(device)
83
+ igt = igt.to(device)
84
+
85
+ output = model(template, source)
86
+ loss_val = FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])
87
+ # print(loss_val.item())
88
+
89
+ # forward + backward + optimize
90
+ optimizer.zero_grad()
91
+ loss_val.backward()
92
+ optimizer.step()
93
+
94
+ train_loss += loss_val.item()
95
+ count += 1
96
+
97
+ train_loss = float(train_loss)/count
98
+ return train_loss
99
+
100
+ def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
101
+ learnable_params = filter(lambda p: p.requires_grad, model.parameters())
102
+ if args.optimizer == 'Adam':
103
+ optimizer = torch.optim.Adam(learnable_params)
104
+ else:
105
+ optimizer = torch.optim.SGD(learnable_params, lr=0.1)
106
+
107
+ if checkpoint is not None:
108
+ min_loss = checkpoint['min_loss']
109
+ optimizer.load_state_dict(checkpoint['optimizer'])
110
+
111
+ best_test_loss = np.inf
112
+
113
+ for epoch in range(args.start_epoch, args.epochs):
114
+ train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
115
+ test_loss = test_one_epoch(args.device, model, test_loader)
116
+
117
+ if test_loss<best_test_loss:
118
+ best_test_loss = test_loss
119
+ snap = {'epoch': epoch + 1,
120
+ 'model': model.state_dict(),
121
+ 'min_loss': best_test_loss,
122
+ 'optimizer' : optimizer.state_dict(),}
123
+ torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
124
+ torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
125
+ torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))
126
+
127
+ torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
128
+ torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
129
+ torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))
130
+
131
+ boardio.add_scalar('Train Loss', train_loss, epoch+1)
132
+ boardio.add_scalar('Test Loss', test_loss, epoch+1)
133
+ boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)
134
+
135
+ textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
136
+
137
+ def options():
138
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
139
+ parser.add_argument('--exp_name', type=str, default='exp_pnlk', metavar='N',
140
+ help='Name of the experiment')
141
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
142
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
143
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
144
+
145
+ # settings for input data
146
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
147
+ metavar='DATASET', help='dataset type (default: modelnet)')
148
+ parser.add_argument('--num_points', default=1024, type=int,
149
+ metavar='N', help='points in point-cloud (default: 1024)')
150
+
151
+ # settings for PointNet
152
+ parser.add_argument('--fine_tune_pointnet', default='tune', type=str, choices=['fixed', 'tune'],
153
+ help='train pointnet (default: tune)')
154
+ parser.add_argument('--transfer_ptnet_weights', default='./checkpoints/exp_classifier/models/best_ptnet_model.t7', type=str,
155
+ metavar='PATH', help='path to pointnet features file')
156
+ parser.add_argument('--emb_dims', default=1024, type=int,
157
+ metavar='K', help='dim. of the feature vector (default: 1024)')
158
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
159
+ help='symmetric function (default: max)')
160
+
161
+ # settings for on training
162
+ parser.add_argument('--seed', type=int, default=1234)
163
+ parser.add_argument('-j', '--workers', default=4, type=int,
164
+ metavar='N', help='number of data loading workers (default: 4)')
165
+ parser.add_argument('-b', '--batch_size', default=10, type=int,
166
+ metavar='N', help='mini-batch size (default: 32)')
167
+ parser.add_argument('--epochs', default=200, type=int,
168
+ metavar='N', help='number of total epochs to run')
169
+ parser.add_argument('--start_epoch', default=0, type=int,
170
+ metavar='N', help='manual epoch number (useful on restarts)')
171
+ parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
172
+ metavar='METHOD', help='name of an optimizer (default: Adam)')
173
+ parser.add_argument('--resume', default='', type=str,
174
+ metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
175
+ parser.add_argument('--pretrained', default='', type=str,
176
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
177
+ parser.add_argument('--device', default='cuda:0', type=str,
178
+ metavar='DEVICE', help='use CUDA if available')
179
+
180
+ args = parser.parse_args()
181
+ return args
182
+
183
+ def main():
184
+ args = options()
185
+
186
+ torch.backends.cudnn.deterministic = True
187
+ torch.manual_seed(args.seed)
188
+ torch.cuda.manual_seed_all(args.seed)
189
+ np.random.seed(args.seed)
190
+
191
+ boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
192
+ _init_(args)
193
+
194
+ textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
195
+ textio.cprint(str(args))
196
+
197
+
198
+ trainset = RegistrationData('PointNetLK', ModelNet40Data(train=True))
199
+ testset = RegistrationData('PointNetLK', ModelNet40Data(train=False))
200
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
201
+ test_loader = DataLoader(testset, batch_size=8, shuffle=False, drop_last=False, num_workers=args.workers)
202
+
203
+ if not torch.cuda.is_available():
204
+ args.device = 'cpu'
205
+ args.device = torch.device(args.device)
206
+
207
+ # Create PointNet Model.
208
+ ptnet = PointNet(emb_dims=args.emb_dims, use_bn=True)
209
+
210
+ if args.transfer_ptnet_weights and os.path.isfile(args.transfer_ptnet_weights):
211
+ ptnet.load_state_dict(torch.load(args.transfer_ptnet_weights, map_location='cpu'))
212
+
213
+ if args.fine_tune_pointnet == 'tune':
214
+ pass
215
+ elif args.fine_tune_pointnet == 'fixed':
216
+ for param in ptnet.parameters():
217
+ param.requires_grad_(False)
218
+
219
+ model = PointNetLK(feature_model=ptnet)
220
+ model = model.to(args.device)
221
+
222
+ checkpoint = None
223
+ if args.resume:
224
+ assert os.path.isfile(args.resume)
225
+ checkpoint = torch.load(args.resume)
226
+ args.start_epoch = checkpoint['epoch']
227
+ model.load_state_dict(checkpoint['model'])
228
+
229
+ if args.pretrained:
230
+ assert os.path.isfile(args.pretrained)
231
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
232
+ model.to(args.device)
233
+
234
+ if args.eval:
235
+ test(args, model, test_loader, textio)
236
+ else:
237
+ train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
238
+
239
+ if __name__ == '__main__':
240
+ main()
thirdparty/learning3d/examples/train_dcp.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import logging
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ from torch.utils.data import DataLoader
11
+ from tensorboardX import SummaryWriter
12
+ from tqdm import tqdm
13
+
14
+ # Only if the files are in example folder.
15
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
+ if BASE_DIR[-8:] == 'examples':
17
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
18
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
19
+
20
+ from learning3d.models import DGCNN, DCP
21
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
22
+
23
+ def _init_(args):
24
+ if not os.path.exists('checkpoints'):
25
+ os.makedirs('checkpoints')
26
+ if not os.path.exists('checkpoints/' + args.exp_name):
27
+ os.makedirs('checkpoints/' + args.exp_name)
28
+ if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
29
+ os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
30
+ os.system('cp train_dcp.py checkpoints' + '/' + args.exp_name + '/' + 'train.py.backup')
31
+
32
+ class IOStream:
33
+ def __init__(self, path):
34
+ self.f = open(path, 'a')
35
+
36
+ def cprint(self, text):
37
+ print(text)
38
+ self.f.write(text + '\n')
39
+ self.f.flush()
40
+
41
+ def close(self):
42
+ self.f.close()
43
+
44
+ def get_transformations(igt):
45
+ R_ba = igt[:, 0:3, 0:3] # Ps = R_ba * Pt
46
+ translation_ba = igt[:, 0:3, 3].unsqueeze(2) # Ps = Pt + t_ba
47
+ R_ab = R_ba.permute(0, 2, 1) # Pt = R_ab * Ps
48
+ translation_ab = -torch.bmm(R_ab, translation_ba) # Pt = Ps + t_ab
49
+ return R_ab, translation_ab, R_ba, translation_ba
50
+
51
+ def test_one_epoch(device, model, test_loader):
52
+ model.eval()
53
+ test_loss = 0.0
54
+ pred = 0.0
55
+ count = 0
56
+ for i, data in enumerate(tqdm(test_loader)):
57
+ template, source, igt = data
58
+ transformations = get_transformations(igt)
59
+ transformations = [t.to(device) for t in transformations]
60
+ R_ab, translation_ab, R_ba, translation_ba = transformations
61
+
62
+ template = template.to(device)
63
+ source = source.to(device)
64
+ igt = igt.to(device)
65
+
66
+ output = model(template, source)
67
+ identity = torch.eye(3).cuda().unsqueeze(0).repeat(template.shape[0], 1, 1)
68
+ loss_val = torch.nn.functional.mse_loss(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity) \
69
+ + torch.nn.functional.mse_loss(output['est_t'], translation_ab[:,:,0])
70
+
71
+ cycle_loss = torch.nn.functional.mse_loss(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity) \
72
+ + torch.nn.functional.mse_loss(output['est_t_'], translation_ba[:,:,0])
73
+ loss_val = loss_val + cycle_loss * 0.1
74
+
75
+ test_loss += loss_val.item()
76
+ count += 1
77
+
78
+ test_loss = float(test_loss)/count
79
+ return test_loss
80
+
81
+ def test(args, model, test_loader, textio):
82
+ test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)
83
+ textio.cprint('Validation Loss: %f & Validation Accuracy: %f'%(test_loss, test_accuracy))
84
+
85
+ def train_one_epoch(device, model, train_loader, optimizer):
86
+ model.train()
87
+ train_loss = 0.0
88
+ pred = 0.0
89
+ count = 0
90
+ for i, data in enumerate(tqdm(train_loader)):
91
+ template, source, igt = data
92
+ transformations = get_transformations(igt)
93
+ transformations = [t.to(device) for t in transformations]
94
+ R_ab, translation_ab, R_ba, translation_ba = transformations
95
+
96
+ template = template.to(device)
97
+ source = source.to(device)
98
+ igt = igt.to(device)
99
+
100
+ output = model(template, source)
101
+ identity = torch.eye(3).cuda().unsqueeze(0).repeat(template.shape[0], 1, 1)
102
+ loss_val = torch.nn.functional.mse_loss(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity) \
103
+ + torch.nn.functional.mse_loss(output['est_t'], translation_ab[:,:,0])
104
+
105
+ cycle_loss = torch.nn.functional.mse_loss(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity) \
106
+ + torch.nn.functional.mse_loss(output['est_t_'], translation_ba[:,:,0])
107
+ loss_val = loss_val + cycle_loss * 0.1
108
+ # print(loss_val.item())
109
+
110
+ # forward + backward + optimize
111
+ optimizer.zero_grad()
112
+ loss_val.backward()
113
+ optimizer.step()
114
+
115
+ train_loss += loss_val.item()
116
+ count += 1
117
+
118
+ train_loss = float(train_loss)/count
119
+ return train_loss
120
+
121
+ def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
122
+ learnable_params = filter(lambda p: p.requires_grad, model.parameters())
123
+ if args.optimizer == 'Adam':
124
+ optimizer = torch.optim.Adam(learnable_params)
125
+ else:
126
+ optimizer = torch.optim.SGD(learnable_params, lr=0.1)
127
+
128
+ if checkpoint is not None:
129
+ min_loss = checkpoint['min_loss']
130
+ optimizer.load_state_dict(checkpoint['optimizer'])
131
+
132
+ best_test_loss = np.inf
133
+
134
+ for epoch in range(args.start_epoch, args.epochs):
135
+ train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
136
+ test_loss = test_one_epoch(args.device, model, test_loader)
137
+
138
+ if test_loss<best_test_loss:
139
+ best_test_loss = test_loss
140
+ snap = {'epoch': epoch + 1,
141
+ 'model': model.state_dict(),
142
+ 'min_loss': best_test_loss,
143
+ 'optimizer' : optimizer.state_dict(),}
144
+ torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
145
+ torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
146
+ torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))
147
+
148
+ torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
149
+ torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
150
+ torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))
151
+
152
+ boardio.add_scalar('Train Loss', train_loss, epoch+1)
153
+ boardio.add_scalar('Test Loss', test_loss, epoch+1)
154
+ boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)
155
+
156
+ textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
157
+
158
+ def options():
159
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
160
+ parser.add_argument('--exp_name', type=str, default='exp_dcp', metavar='N',
161
+ help='Name of the experiment')
162
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
163
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
164
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
165
+
166
+ # settings for input data
167
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
168
+ metavar='DATASET', help='dataset type (default: modelnet)')
169
+ parser.add_argument('--num_points', default=1024, type=int,
170
+ metavar='N', help='points in point-cloud (default: 1024)')
171
+
172
+ # settings for PointNet
173
+ parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
174
+ help='train pointnet (default: tune)')
175
+ parser.add_argument('--emb_dims', default=1024, type=int,
176
+ metavar='K', help='dim. of the feature vector (default: 1024)')
177
+ parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
178
+ help='symmetric function (default: max)')
179
+
180
+ # settings for on training
181
+ parser.add_argument('--seed', type=int, default=1234)
182
+ parser.add_argument('-j', '--workers', default=4, type=int,
183
+ metavar='N', help='number of data loading workers (default: 4)')
184
+ parser.add_argument('-b', '--batch_size', default=32, type=int,
185
+ metavar='N', help='mini-batch size (default: 32)')
186
+ parser.add_argument('--epochs', default=200, type=int,
187
+ metavar='N', help='number of total epochs to run')
188
+ parser.add_argument('--start_epoch', default=0, type=int,
189
+ metavar='N', help='manual epoch number (useful on restarts)')
190
+ parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
191
+ metavar='METHOD', help='name of an optimizer (default: Adam)')
192
+ parser.add_argument('--resume', default='', type=str,
193
+ metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
194
+ parser.add_argument('--pretrained', default='', type=str,
195
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
196
+ parser.add_argument('--device', default='cuda:0', type=str,
197
+ metavar='DEVICE', help='use CUDA if available')
198
+
199
+ args = parser.parse_args()
200
+ return args
201
+
202
+ def main():
203
+ args = options()
204
+
205
+ torch.backends.cudnn.deterministic = True
206
+ torch.manual_seed(args.seed)
207
+ torch.cuda.manual_seed_all(args.seed)
208
+ np.random.seed(args.seed)
209
+
210
+ boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
211
+ _init_(args)
212
+
213
+ textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
214
+ textio.cprint(str(args))
215
+
216
+
217
+ trainset = RegistrationData('DCP', ModelNet40Data(train=True))
218
+ testset = RegistrationData('DCP', ModelNet40Data(train=False))
219
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
220
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
221
+
222
+ if not torch.cuda.is_available():
223
+ args.device = 'cpu'
224
+ args.device = torch.device(args.device)
225
+
226
+ # Create PointNet Model.
227
+ dgcnn = DGCNN(emb_dims=args.emb_dims)
228
+ model = DCP(feature_model=dgcnn, cycle=True)
229
+ model = model.to(args.device)
230
+
231
+ checkpoint = None
232
+ if args.resume:
233
+ assert os.path.isfile(args.resume)
234
+ checkpoint = torch.load(args.resume)
235
+ args.start_epoch = checkpoint['epoch']
236
+ model.load_state_dict(checkpoint['model'])
237
+
238
+ if args.pretrained:
239
+ assert os.path.isfile(args.pretrained)
240
+ model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
241
+ model.to(args.device)
242
+
243
+ if args.eval:
244
+ test(args, model, test_loader, textio)
245
+ else:
246
+ train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
247
+
248
+ if __name__ == '__main__':
249
+ main()
thirdparty/learning3d/examples/train_deepgmr.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import DeepGMR
22
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
23
+
24
+ def display_open3d(template, source, transformed_source):
25
+ template_ = o3d.geometry.PointCloud()
26
+ source_ = o3d.geometry.PointCloud()
27
+ transformed_source_ = o3d.geometry.PointCloud()
28
+ template_.points = o3d.utility.Vector3dVector(template)
29
+ source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
30
+ transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
31
+ template_.paint_uniform_color([1, 0, 0])
32
+ source_.paint_uniform_color([0, 1, 0])
33
+ transformed_source_.paint_uniform_color([0, 0, 1])
34
+ o3d.visualization.draw_geometries([template_, source_, transformed_source_])
35
+
36
+ def rotation_error(R, R_gt):
37
+ cos_theta = (torch.einsum('bij,bij->b', R, R_gt) - 1) / 2
38
+ cos_theta = torch.clamp(cos_theta, -1, 1)
39
+ return torch.acos(cos_theta) * 180 / math.pi
40
+
41
+ def translation_error(t, t_gt):
42
+ return torch.norm(t - t_gt, dim=1)
43
+
44
+ def rmse(pts, T, T_gt):
45
+ pts_pred = pts @ T[:, :3, :3].transpose(1, 2) + T[:, :3, 3].unsqueeze(1)
46
+ pts_gt = pts @ T_gt[:, :3, :3].transpose(1, 2) + T_gt[:, :3, 3].unsqueeze(1)
47
+ return torch.norm(pts_pred - pts_gt, dim=2).mean(dim=1)
48
+
49
+ def test_one_epoch(device, model, test_loader):
50
+ model.eval()
51
+ test_loss = 0.0
52
+ pred = 0.0
53
+ count = 0
54
+ rotation_errors, translation_errors, rmses = [], [], []
55
+
56
+ for i, data in enumerate(tqdm(test_loader)):
57
+ template, source, igt = data
58
+
59
+ template = template.to(device)
60
+ source = source.to(device)
61
+ igt = igt.to(device)
62
+
63
+ output = model(template, source)
64
+
65
+ eye = torch.eye(4).expand_as(igt).to(igt.device)
66
+ mse1 = F.mse_loss(output['est_T_inverse'] @ torch.inverse(igt), eye)
67
+ mse2 = F.mse_loss(output['est_T'] @ igt, eye)
68
+ loss = mse1 + mse2
69
+
70
+ r_err = rotation_error(est_T_inverse[:, :3, :3], igt[:, :3, :3])
71
+ t_err = translation_error(est_T_inverse[:, :3, 3], igt[:, :3, 3])
72
+ rmse_val = rmse(template[:, :100], est_T_inverse, igt)
73
+ rotation_errors.append(r_err)
74
+ translation_errors.append(t_err)
75
+ rmses.append(rmse_val)
76
+
77
+ test_loss += loss_val.item()
78
+ count += 1
79
+
80
+ test_loss = float(test_loss)/count
81
+ print("Mean rotation error: {}, Mean translation error: {} and Mean RMSE: {}".format(np.mean(rotation_errors), np.mean(translation_errors), np.mean(rmses)))
82
+ return test_loss
83
+
84
+ def test(args, model, test_loader, textio):
85
+ test_loss = test_one_epoch(args.device, model, test_loader)
86
+ textio.cprint('Validation Loss: %f'%(test_loss))
87
+
88
+ def train_one_epoch(device, model, train_loader, optimizer):
89
+ model.train()
90
+ train_loss = 0.0
91
+ pred = 0.0
92
+ count = 0
93
+ for i, data in enumerate(tqdm(train_loader)):
94
+ template, source, igt = data
95
+
96
+ template = template.to(device)
97
+ source = source.to(device)
98
+ igt = igt.to(device)
99
+
100
+ output = model(template, source)
101
+
102
+ eye = torch.eye(4).expand_as(igt).to(igt.device)
103
+ mse1 = F.mse_loss(output['est_T_inverse'] @ torch.inverse(igt), eye)
104
+ mse2 = F.mse_loss(output['est_T'] @ igt, eye)
105
+ loss = mse1 + mse2
106
+
107
+ # forward + backward + optimize
108
+ optimizer.zero_grad()
109
+ loss_val.backward()
110
+ optimizer.step()
111
+
112
+ train_loss += loss_val.item()
113
+ count += 1
114
+
115
+ train_loss = float(train_loss)/count
116
+ return train_loss
117
+
118
+ def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
119
+ learnable_params = filter(lambda p: p.requires_grad, model.parameters())
120
+ if args.optimizer == 'Adam':
121
+ optimizer = torch.optim.Adam(learnable_params)
122
+ else:
123
+ optimizer = torch.optim.SGD(learnable_params, lr=0.1)
124
+
125
+ if checkpoint is not None:
126
+ min_loss = checkpoint['min_loss']
127
+ optimizer.load_state_dict(checkpoint['optimizer'])
128
+
129
+ best_test_loss = np.inf
130
+
131
+ for epoch in range(args.start_epoch, args.epochs):
132
+ train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
133
+ test_loss = test_one_epoch(args.device, model, test_loader)
134
+
135
+ if test_loss<best_test_loss:
136
+ best_test_loss = test_loss
137
+ snap = {'epoch': epoch + 1,
138
+ 'model': model.state_dict(),
139
+ 'min_loss': best_test_loss,
140
+ 'optimizer' : optimizer.state_dict(),}
141
+ torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
142
+ torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
143
+ torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))
144
+
145
+ torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
146
+ torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
147
+ torch.save(model.feature_model.state_dict(), 'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))
148
+
149
+ boardio.add_scalar('Train Loss', train_loss, epoch+1)
150
+ boardio.add_scalar('Test Loss', test_loss, epoch+1)
151
+ boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)
152
+
153
+ textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
154
+
155
+ def options():
156
+ parser = argparse.ArgumentParser(description='Point Cloud Registration')
157
+ parser.add_argument('--exp_name', type=str, default='exp_deepgmr', metavar='N',
158
+ help='Name of the experiment')
159
+ parser.add_argument('--dataset_path', type=str, default='ModelNet40',
160
+ metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
161
+ parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
162
+
163
+ # settings for input data
164
+ parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
165
+ metavar='DATASET', help='dataset type (default: modelnet)')
166
+ parser.add_argument('--num_points', default=1024, type=int,
167
+ metavar='N', help='points in point-cloud (default: 1024)')
168
+
169
+ parser.add_argument('--nearest_neighbors', default=20, type=int,
170
+ metavar='K', help='No of nearest neighbors to be estimated.')
171
+ parser.add_argument('--use_rri', default=True, type=bool,
172
+ help='Find nearest neighbors to estimate features from PointNet.')
173
+
174
+ # settings for on training
175
+ parser.add_argument('-j', '--workers', default=4, type=int,
176
+ metavar='N', help='number of data loading workers (default: 4)')
177
+ parser.add_argument('-b', '--batch_size', default=32, type=int,
178
+ metavar='N', help='mini-batch size (default: 32)')
179
+ parser.add_argument('--pretrained', default='', type=str,
180
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
181
+ parser.add_argument('--device', default='cuda:0', type=str,
182
+ metavar='DEVICE', help='use CUDA if available')
183
+ parser.add_argument('--epochs', default=200, type=int,
184
+ metavar='N', help='number of total epochs to run')
185
+ parser.add_argument('--start_epoch', default=0, type=int,
186
+ metavar='N', help='manual epoch number (useful on restarts)')
187
+ parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
188
+ metavar='METHOD', help='name of an optimizer (default: Adam)')
189
+ parser.add_argument('--resume', default='', type=str,
190
+ metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
191
+ parser.add_argument('--pretrained', default='', type=str,
192
+ metavar='PATH', help='path to pretrained model file (default: null (no-use))')
193
+ parser.add_argument('--device', default='cuda:0', type=str,
194
+ metavar='DEVICE', help='use CUDA if available')
195
+
196
+ args = parser.parse_args()
197
+ if args.nearest_neighbors > 0:
198
+ args.use_rri = True
199
+ return args
200
+
201
+ def main():
202
+ args = options()
203
+ torch.backends.cudnn.deterministic = True
204
+ torch.manual_seed(args.seed)
205
+ torch.cuda.manual_seed_all(args.seed)
206
+ np.random.seed(args.seed)
207
+
208
+ boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
209
+ _init_(args)
210
+
211
+ textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
212
+ textio.cprint(str(args))
213
+
214
+ trainset = RegistrationData('DeepGMR', ModelNet40Data(train=True), additional_params={'nearest_neighbors': args.nearest_neighbors})
215
+ testset = RegistrationData('DeepGMR', ModelNet40Data(train=False), additional_params={'nearest_neighbors': args.nearest_neighbors})
216
+ train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
217
+ test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
218
+
219
+ if not torch.cuda.is_available():
220
+ args.device = 'cpu'
221
+ args.device = torch.device(args.device)
222
+
223
+ model = DeepGMR(use_rri=args.use_rri, nearest_neighbors=args.nearest_neighbors)
224
+ model = model.to(args.device)
225
+
226
+ checkpoint = None
227
+ if args.resume:
228
+ assert os.path.isfile(args.resume)
229
+ checkpoint = torch.load(args.resume)
230
+ args.start_epoch = checkpoint['epoch']
231
+ model.load_state_dict(checkpoint['model'])
232
+
233
+ if args.pretrained:
234
+ assert os.path.isfile(args.pretrained)
235
+ model.load_state_dict(torch.load(args.pretrained), strict=False)
236
+ model.to(args.device)
237
+
238
+ if args.eval:
239
+ test(args, model, test_loader, textio)
240
+ else:
241
+ train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
242
+
243
+ if __name__ == '__main__':
244
+ main()