diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b10163f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,14 @@
+deps/*
+!deps/.gitkeep
+outputs/*
+!outputs/.gitkeep
+src_shot/build/*
+**/__pycache__
+license/*
+!license/.gitkeep
+*.tar
+*.so
+.TimeRecord
+imgui.ini
+temp_data/*
+!temp_data/.gitkeep
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..d8a1dd6
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 RPMArt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..aa78ded
--- /dev/null
+++ b/README.md
@@ -0,0 +1,156 @@
+# RPMArt
+
+Official implementation of the paper [RPMArt: Towards Robust Perception and Manipulation for Articulated Objects](https://arxiv.org/abs/2403.16023), accepted to [IROS 2024](https://iros2024-abudhabi.org).
+
+For more information, please visit our [project website](https://r-pmart.github.io/).
+
+---
+
+## 🛠 Installation
+### 💻 Server-side
+1. Clone this repo.
+ ```bash
+ git clone git@github.com:R-PMArt/rpmart.git
+ cd rpmart
+ ```
+
+2. Create a [Conda](https://conda.org/) environment.
+ ```bash
+ conda create -n rpmart python=3.8
+ conda activate rpmart
+ ```
+
+3. Install [PyTorch](https://pytorch.org/).
+ ```bash
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+ ```
+
+4. Install [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr).
+ ```bash
+ cd deps
+ git clone git@github.com:ildoonet/pytorch-gradual-warmup-lr.git
+ cd pytorch-gradual-warmup-lr
+ pip install .
+ cd ../..
+ ```
+
+5. Install [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine).
+ ```bash
+ conda install openblas-devel -c anaconda
+ export CUDA_HOME=/usr/local/cuda
+ pip install ninja
+ cd deps
+ git clone git@github.com:NVIDIA/MinkowskiEngine.git
+ cd MinkowskiEngine
+ pip install -U . --no-deps --install-option="--blas_include_dirs=${CONDA_PREFIX}/include" --install-option="--blas=openblas"
+ cd ../..
+ ```
+
+6. Install [CuPy](https://cupy.dev/).
+ ```bash
+ pip install cupy-cuda11x
+ ```
+
+7. Install the customized [SAPIEN](https://sapien.ucsd.edu/) build (the Where2Act wheel).
+ ```bash
+ pip install http://download.cs.stanford.edu/orion/where2act/where2act_sapien_wheels/sapien-0.8.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
+ ```
+
+8. Install [AnyGrasp](https://github.com/graspnet/anygrasp_sdk).
+ ```bash
+ pip install cvxopt munch graspnetAPI
+   # follow the AnyGrasp instructions to obtain the license files, model weights and binary code
+ ```
+
+9. Install other dependencies.
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+10. Build `shot`.
+ ```bash
+    # you may need to install pybind11 and PCL first
+ cd src_shot
+ mkdir build
+ cd build
+ cmake ..
+ make
+ cd ../..
+ ```
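+
+    If the build succeeds, the compiled module used by the datasets should be importable from the repository root (a quick sanity check, not an official step):
+    ```bash
+    python -c "from src_shot.build import shot; print(shot.compute, shot.estimate_normal)"
+    ```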
+
+### 🦾 Robot-side
+1. Make sure the real-time kernel (`rt-linux`) is enabled for the Franka Emika Panda.
+ ```bash
+ uname -a
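+   # an RT-patched kernel typically reports PREEMPT_RT (or PREEMPT RT) in this output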
+ ```
+
+2. Install `frankx` for the robot and `pyrealsense2` for the camera.
+ ```bash
+ pip install frankx pyrealsense2
+ ```
+
+3. Install `paramiko` for communicating with the server.
+ ```bash
+ pip install paramiko
+ ```
+
+## 🏃‍♂️ Run
+1. Train or download [RoArtNet](https://huggingface.co/dadadadawjb/RoArtNet).
+ ```bash
+ bash scripts/train.sh
+ ```
+
+2. Test RoArtNet.
+ ```bash
+ bash scripts/test.sh
+ ```
+
+3. Evaluate RPMArt.
+ ```bash
+ bash scripts/eval_roartnet.sh
+ ```
+
+4. Test RoArtNet on [RealArt-6](https://huggingface.co/datasets/dadadadawjb/RealArt-6).
+ ```bash
+ bash scripts/test_real.sh
+ ```
+
+5. Evaluate RPMArt in the real world.
+ ```bash
+ # server side
+ bash scripts/real_service.sh
+
+ # robot side
+ python real_eval.py
+ ```
+
+## 🙏 Acknowledgement
+* Our simulation environment is adapted from [VAT-Mart](https://github.com/warshallrho/VAT-Mart).
+* Our voting module is adapted from [CPPF](https://github.com/qq456cvb/CPPF) and [CPPF++](https://github.com/qq456cvb/CPPF2).
+
+## ✍ Citation
+If you find our work useful, please consider citing:
+```bibtex
+@article{wang2024rpmart,
+ title={RPMArt: Towards Robust Perception and Manipulation for Articulated Objects},
+ author={Wang, Junbo and Liu, Wenhai and Yu, Qiaojun and You, Yang and Liu, Liu and Wang, Weiming and Lu, Cewu},
+ journal={arXiv preprint arXiv:2403.16023},
+ year={2024}
+}
+```
+
+## 📃 License
+This repository is released under the [MIT](https://mit-license.org/) license.
diff --git a/assets/teaser.png b/assets/teaser.png
new file mode 100644
index 0000000..9ea5386
Binary files /dev/null and b/assets/teaser.png differ
diff --git a/configs/.gitkeep b/configs/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/configs/algorithm/formal.yaml b/configs/algorithm/formal.yaml
new file mode 100644
index 0000000..5f6aaed
--- /dev/null
+++ b/configs/algorithm/formal.yaml
@@ -0,0 +1,29 @@
+sampling:
+ sample_tuples_num: 100000
+ tuple_more_num: 3
+
+shot_encoder:
+ hidden_dims: [128, 128, 128, 128, 128]
+ feature_dim: 64
+ bn: False
+  ln: False # at most one of bn and ln can be enabled
+ dropout: 0
+
+encoder:
+ hidden_dims: [128, 128, 128, 128, 128]
+ bn: False
+  ln: False # at most one of bn and ln can be enabled
+ dropout: 0
+
+voting:
+ rot_bin_num: 36 # 5 degree
+ voting_num: 120 # 3 degree
+ angle_tol: 1.5 # 10 degree
+ # angle_tol: 0.35 # 5 degree
+ translation2pc: False
+ rotation_cluster: False
+ multi_candidate: False
+ candidate_threshold: 0.5
+ rotation_multi_neighbor: False
+ neighbor_threshold: 10
+ bmm_size: 100000
diff --git a/configs/data/camera_config.json b/configs/data/camera_config.json
new file mode 100644
index 0000000..a04d979
--- /dev/null
+++ b/configs/data/camera_config.json
@@ -0,0 +1,15 @@
+{
+ "intrinsics": {
+ "fovx": [70, 70],
+ "fovy": [40, 60],
+ "height": 480,
+ "width": 640,
+ "near": 0.1,
+ "far": 100.0
+ },
+ "extrinsics": {
+ "dist": [0.6, 1.2],
+ "phi": [0, 60],
+ "theta": [120, 240]
+ }
+}
\ No newline at end of file
diff --git a/configs/data/object_config.json b/configs/data/object_config.json
new file mode 100644
index 0000000..f8c2608
--- /dev/null
+++ b/configs/data/object_config.json
@@ -0,0 +1,30 @@
+{
+ "scale_min": 0.8,
+ "scale_max": 1.1,
+ "Microwave": {
+ "size": 0.4,
+ "joint_num": 1
+ },
+ "StorageFurniture": {
+ "size": 0.375,
+ "joint_num": 2,
+ "joint_types": ["prismatic", "revolute"]
+ },
+ "Refrigerator": {
+ "size": 0.4,
+ "joint_num": 1
+ },
+ "Safe": {
+ "size": 0.4,
+ "joint_num": 1
+ },
+ "WashingMachine": {
+ "size": 0.15,
+ "joint_num": 1
+ },
+ "Drawer": {
+ "size": 0.4,
+ "joint_num": 3,
+ "joint_types": ["prismatic", "prismatic", "prismatic"]
+ }
+}
\ No newline at end of file
diff --git a/configs/dataset/formal_drawer.yaml b/configs/dataset/formal_drawer.yaml
new file mode 100644
index 0000000..47ea190
--- /dev/null
+++ b/configs/dataset/formal_drawer.yaml
@@ -0,0 +1,13 @@
+train_path: '/data2/junbo/sapien4/train'
+test_path: '/data2/junbo/sapien4/test'
+train_categories: ['Drawer']
+test_categories: ['Drawer']
+joint_num: 3
+resolution: 2.5e-2
+# resolution: 1e-2
+receptive_field: 10
+normalize: 'bound'
+# normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/formal_microwave.yaml b/configs/dataset/formal_microwave.yaml
new file mode 100644
index 0000000..183ad92
--- /dev/null
+++ b/configs/dataset/formal_microwave.yaml
@@ -0,0 +1,13 @@
+train_path: '/data2/junbo/sapien4/train'
+test_path: '/data2/junbo/sapien4/test'
+train_categories: ['Microwave']
+test_categories: ['Microwave']
+joint_num: 1
+resolution: 2.5e-2
+# resolution: 1e-2
+receptive_field: 10
+normalize: 'bound'
+# normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/formal_refrigerator.yaml b/configs/dataset/formal_refrigerator.yaml
new file mode 100644
index 0000000..137d540
--- /dev/null
+++ b/configs/dataset/formal_refrigerator.yaml
@@ -0,0 +1,13 @@
+train_path: '/data2/junbo/sapien4/train'
+test_path: '/data2/junbo/sapien4/test'
+train_categories: ['Refrigerator']
+test_categories: ['Refrigerator']
+joint_num: 1
+resolution: 2.5e-2
+# resolution: 1e-2
+receptive_field: 10
+normalize: 'bound'
+# normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/formal_safe.yaml b/configs/dataset/formal_safe.yaml
new file mode 100644
index 0000000..48996ab
--- /dev/null
+++ b/configs/dataset/formal_safe.yaml
@@ -0,0 +1,13 @@
+train_path: '/data2/junbo/sapien4/train'
+test_path: '/data2/junbo/sapien4/test'
+train_categories: ['Safe']
+test_categories: ['Safe']
+joint_num: 1
+resolution: 2.5e-2
+# resolution: 1e-2
+receptive_field: 10
+normalize: 'bound'
+# normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/formal_storagefurniture.yaml b/configs/dataset/formal_storagefurniture.yaml
new file mode 100644
index 0000000..679acb6
--- /dev/null
+++ b/configs/dataset/formal_storagefurniture.yaml
@@ -0,0 +1,13 @@
+train_path: '/data2/junbo/sapien4/train'
+test_path: '/data2/junbo/sapien4/test'
+train_categories: ['StorageFurniture']
+test_categories: ['StorageFurniture']
+joint_num: 2
+resolution: 2.5e-2
+# resolution: 1e-2
+receptive_field: 10
+normalize: 'bound'
+# normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/formal_washingmachine.yaml b/configs/dataset/formal_washingmachine.yaml
new file mode 100644
index 0000000..1c53fef
--- /dev/null
+++ b/configs/dataset/formal_washingmachine.yaml
@@ -0,0 +1,13 @@
+train_path: '/data2/junbo/sapien4/train'
+test_path: '/data2/junbo/sapien4/test'
+train_categories: ['WashingMachine']
+test_categories: ['WashingMachine']
+joint_num: 1
+resolution: 2.5e-2
+# resolution: 1e-2
+receptive_field: 10
+normalize: 'bound'
+# normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/real_with_with.yaml b/configs/dataset/real_with_with.yaml
new file mode 100644
index 0000000..cf0e622
--- /dev/null
+++ b/configs/dataset/real_with_with.yaml
@@ -0,0 +1,11 @@
+path: '/data2/junbo/RealArt-6/with_table/microwave'
+instances: ['0_with_chaos', '1_with_chaos', '2_with_chaos', '3_with_chaos', '4_with_chaos']
+joint_num: 1
+# resolution: 2.5e-2
+resolution: 1e-2
+receptive_field: 10
+# normalize: 'bound'
+normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/real_with_without.yaml b/configs/dataset/real_with_without.yaml
new file mode 100644
index 0000000..b0151c6
--- /dev/null
+++ b/configs/dataset/real_with_without.yaml
@@ -0,0 +1,11 @@
+path: '/data2/junbo/RealArt-6/with_table/microwave'
+instances: ['0_without_chaos', '1_without_chaos', '2_without_chaos', '3_without_chaos', '4_without_chaos']
+joint_num: 1
+# resolution: 2.5e-2
+resolution: 1e-2
+receptive_field: 10
+# normalize: 'bound'
+normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/real_without_with.yaml b/configs/dataset/real_without_with.yaml
new file mode 100644
index 0000000..2602e7b
--- /dev/null
+++ b/configs/dataset/real_without_with.yaml
@@ -0,0 +1,11 @@
+path: '/data2/junbo/RealArt-6/without_table/microwave'
+instances: ['0_with_chaos', '1_with_chaos', '2_with_chaos', '3_with_chaos', '4_with_chaos']
+joint_num: 1
+# resolution: 2.5e-2
+resolution: 1e-2
+receptive_field: 10
+# normalize: 'bound'
+normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/dataset/real_without_without.yaml b/configs/dataset/real_without_without.yaml
new file mode 100644
index 0000000..1960c30
--- /dev/null
+++ b/configs/dataset/real_without_without.yaml
@@ -0,0 +1,11 @@
+path: '/data2/junbo/RealArt-6/without_table/microwave'
+instances: ['0_without_chaos', '1_without_chaos', '2_without_chaos', '3_without_chaos', '4_without_chaos']
+joint_num: 1
+# resolution: 2.5e-2
+resolution: 1e-2
+receptive_field: 10
+# normalize: 'bound'
+normalize: 'none'
+sample_points_num: 1024
+rgb: False
+denoise: False
diff --git a/configs/eval_config.yaml b/configs/eval_config.yaml
new file mode 100644
index 0000000..0a387c5
--- /dev/null
+++ b/configs/eval_config.yaml
@@ -0,0 +1,23 @@
+dataset: ???
+algorithm:
+ voting:
+ voting_num: 120 # 3 degree
+ # angle_tol: 1.5 # 10 degree
+ angle_tol: 0.35 # 5 degree
+ translation2pc: False
+ rotation_cluster: True
+ multi_candidate: True
+ candidate_threshold: 0.5
+ rotation_multi_neighbor: True
+ neighbor_threshold: 5
+ bmm_size: 100000
+
+trained:
+ path: {
+ "Microwave": "./outputs/train/formal_microwave_2024_01_16_12_16_57",
+ "Refrigerator": "./outputs/train/formal_refrigerator_2024_01_17_22_52_26",
+ "Safe": "./outputs/train/formal_safe_2024_01_17_23_01_44",
+ "StorageFurniture": "./outputs/train/formal_storagefurniture_2024_01_18_11_13_43",
+ "Drawer": "./outputs/train/formal_drawer_2024_01_18_11_18_23",
+ "WashingMachine": "./outputs/train/formal_washingmachine_2024_01_17_23_07_14"
+ }
diff --git a/configs/test_config.yaml b/configs/test_config.yaml
new file mode 100644
index 0000000..e97e1df
--- /dev/null
+++ b/configs/test_config.yaml
@@ -0,0 +1,34 @@
+dataset:
+ noise: True
+ distortion_rate: 0.1
+ distortion_level: 0.01
+ outlier_rate: 0.001
+ outlier_level: 0.5
+algorithm:
+ voting:
+ voting_num: 120 # 3 degree
+ # angle_tol: 1.5 # 10 degree
+ angle_tol: 0.35 # 5 degree
+ translation2pc: False
+ rotation_cluster: True
+ multi_candidate: True
+ candidate_threshold: 0.5
+ rotation_multi_neighbor: True
+ neighbor_threshold: 5
+ bmm_size: 100000
+testing:
+ seed: 42
+ device: 0
+ batch_size: 16
+ num_workers: 8
+ training: False
+
+trained:
+ path: "./outputs/train/formal_microwave_2024_01_16_12_16_57"
+
+abbr: 'noise_microwave'
+vis: False
+
+hydra:
+ run:
+ dir: "./outputs/test/${abbr}_${now:%Y_%m_%d_%H_%M_%S}"
diff --git a/configs/test_gt_config.yaml b/configs/test_gt_config.yaml
new file mode 100644
index 0000000..7b665c0
--- /dev/null
+++ b/configs/test_gt_config.yaml
@@ -0,0 +1,38 @@
+general:
+ seed: 42
+ device: 0
+ batch_size: 16
+ num_workers: 8
+ # abbr: 'none_norm_FTTTT'
+ abbr: 'formal_drawer'
+
+dataset:
+ path: "/data2/junbo/sapien4/test"
+ categories: ['Drawer']
+ joint_num: 3
+ # resolution: 2.5e-2
+ resolution: 1e-2
+ receptive_field: 10
+ denoise: False
+ # normalize: 'median'
+ normalize: 'none'
+ sample_points_num: 1024
+
+algorithm:
+ sample_tuples_num: 100000
+ tuple_more_num: 0
+ translation2pc: False
+ rotation_cluster: True
+ multi_candidate: True
+ candidate_threshold: 0.5
+ rotation_multi_neighbor: True
+ neighbor_threshold: 5
+ # angle_tol: 1.5 # 10 degree
+ angle_tol: 0.35 # 5 degree
+ voting_num: 120 # 3 degree
+ bmm_size: 100000
+
+
+hydra:
+ run:
+ dir: "./outputs/test_gt/${general.abbr}_${now:%Y_%m_%d_%H_%M_%S}"
diff --git a/configs/test_real_config.yaml b/configs/test_real_config.yaml
new file mode 100644
index 0000000..3512036
--- /dev/null
+++ b/configs/test_real_config.yaml
@@ -0,0 +1,31 @@
+defaults:
+ - dataset: real_without_without
+ - _self_
+algorithm:
+ voting:
+ voting_num: 120 # 3 degree
+ # angle_tol: 1.5 # 10 degree
+ angle_tol: 0.35 # 5 degree
+ translation2pc: False
+ rotation_cluster: True
+ multi_candidate: True
+ candidate_threshold: 0.5
+ rotation_multi_neighbor: True
+ neighbor_threshold: 5
+ bmm_size: 100000
+testing:
+ seed: 42
+ device: 0
+ batch_size: 16
+ num_workers: 8
+ training: False
+
+trained:
+ path: "./outputs/train/formal_microwave_2024_01_16_12_16_57"
+
+abbr: 'formal_microwave_without_without'
+vis: False
+
+hydra:
+ run:
+ dir: "./outputs/test_real/${abbr}_${now:%Y_%m_%d_%H_%M_%S}"
diff --git a/configs/train_config.yaml b/configs/train_config.yaml
new file mode 100644
index 0000000..f0c2e76
--- /dev/null
+++ b/configs/train_config.yaml
@@ -0,0 +1,12 @@
+defaults:
+ - dataset: formal_microwave
+ - algorithm: formal
+ - training: formal
+ - _self_
+
+abbr: 'formal_microwave'
+debug: False
+
+hydra:
+ run:
+ dir: "./outputs/train/${abbr}_${now:%Y_%m_%d_%H_%M_%S}"
diff --git a/configs/training/formal.yaml b/configs/training/formal.yaml
new file mode 100644
index 0000000..d0670e4
--- /dev/null
+++ b/configs/training/formal.yaml
@@ -0,0 +1,21 @@
+seed: 42
+device: 0
+batch_size: 16
+num_workers: 8
+epoch_num: 200 # microwave
+# epoch_num: 200 # storagefurniture
+# epoch_num: 300 # refrigerator
+# epoch_num: 120 # safe
+# epoch_num: 200 # washingmachine
+# epoch_num: 100 # drawer
+lr: 1e-3
+weight_decay: 0
+lambda_rot: 0.1
+lambda_afford: 1.0
+lambda_conf: 0.5
+val_training: True
+val_training_num: 64
+val_testing: True
+val_testing_num: 64
+test_train: True
+test_test: True
diff --git a/datasets/point_tuple_dataset.py b/datasets/point_tuple_dataset.py
new file mode 100644
index 0000000..8f29dfb
--- /dev/null
+++ b/datasets/point_tuple_dataset.py
@@ -0,0 +1,257 @@
+from typing import List
+import os
+import json
+import time
+import itertools
+from pathlib import Path
+import tqdm
+import random
+import numpy as np
+import torch
+import MinkowskiEngine as ME
+import open3d as o3d
+
+if __name__ == '__main__':
+ import sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from utilities.data_utils import pc_normalize, joints_normalize, farthest_point_sample, generate_target_tr, generate_target_rot, transform_pc, transform_dir
+from utilities.vis_utils import visualize
+from utilities.env_utils import setup_seed
+from utilities.constants import light_blue_color, red_color, dark_red_color, dark_green_color
+from src_shot.build import shot
+
+
+class ArticulationDataset(torch.utils.data.Dataset):
+ def __init__(self, path:str, instances:List[str], joint_num:int, resolution:float, receptive_field:int,
+ sample_points_num:int, sample_tuples_num:int, tuple_more_num:int,
+ rgb:bool, denoise:bool, normalize:str, debug:bool, vis:bool, is_train:bool) -> None:
+ super().__init__()
+ self.path = path
+ self.instances = instances
+ self.joint_num = joint_num
+ self.resolution = resolution
+ self.receptive_field = receptive_field
+ self.sample_points_num = sample_points_num
+ self.sample_tuples_num = sample_tuples_num
+ self.tuple_more_num = tuple_more_num
+ self.rgb = rgb
+ self.denoise = denoise
+ self.normalize = normalize
+ self.debug = debug
+ self.vis = vis
+ self.is_train = is_train
+ self.fns = sorted(list(itertools.chain(*[list(Path(path).glob('{}/*.npz'.format(instance))) for instance in instances])))
+ self.permutations = list(itertools.permutations(range(self.sample_points_num), 2))
+ if debug:
+ print(f"{len(self.fns) =}", f"{self.fns[0] =}")
+
+ def __len__(self):
+ return len(self.fns)
+
+ def __getitem__(self, idx:int):
+ # load data
+ data = np.load(self.fns[idx])
+
+ pc = data['point_cloud'].astype(np.float32) # (N'', 3)
+ if self.debug:
+ print(f"{pc.shape = }")
+ print(f"{idx = }", f"{self.fns[idx] = }")
+
+ if self.rgb:
+ pc_color = data['rgb'].astype(np.float32) # (N'', 3)
+
+ assert data['joints'].shape[0] == self.joint_num
+ joints = data['joints'].astype(np.float32) # (J, 9)
+ if self.debug:
+ print(f"{joints.shape = }")
+
+ joint_translations = joints[:, 0:3].astype(np.float32) # (J, 3)
+ joint_rotations = joints[:, 3:6].astype(np.float32) # (J, 3)
+ affordable_positions = joints[:, 6:9].astype(np.float32) # (J, 3)
+ joint_types = joints[:, -1].astype(np.int64) # (J,)
+ if self.debug:
+ print(f"{joint_translations.shape = }", f"{joint_rotations.shape = }", f"{affordable_positions.shape = }")
+
+        # TODO: transform to the SAPIEN coordinate convention; not compatible with pybullet-generated data
+ transform_matrix = np.array([[0, 0, 1, 0],
+ [-1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, 0, 1]])
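+        # axis permutation: new_x = z, new_y = -x, new_z = -y, i.e. a camera-style
+        # (right, down, forward) frame mapped to SAPIEN's (forward, left, up) frame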
+ pc = transform_pc(pc, transform_matrix)
+ joint_translations = transform_pc(joint_translations, transform_matrix)
+ joint_rotations = transform_dir(joint_rotations, transform_matrix)
+ affordable_positions = transform_pc(affordable_positions, transform_matrix)
+
+ if self.debug and self.vis:
+ # camera coordinate as (x, y, z) = (right, down, in)
+            # normals may point either inward or outward
+ visualize(pc, pc_color=pc_color if self.rgb else light_blue_color, pc_normal=None,
+ joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions,
+ joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color,
+ whether_frame=True, whether_bbox=True, window_name='before')
+
+ # preprocess
+ suffix = 'cache'
+ if self.denoise:
+ suffix += '_denoise'
+ suffix += f'_{self.normalize}'
+ suffix += f'_{str(self.resolution)}'
+ suffix += f'_{str(self.receptive_field)}'
+ suffix += f'_{str(self.sample_points_num)}'
+ preprocessed_path = os.path.join(os.path.dirname(str(self.fns[idx])), suffix, os.path.basename(str(self.fns[idx])))
+ if os.path.exists(preprocessed_path):
+ preprocessed_data = np.load(preprocessed_path)
+ pc = preprocessed_data['pc'].astype(np.float32)
+ pc_normal = preprocessed_data['pc_normal'].astype(np.float32)
+ pc_shot = preprocessed_data['pc_shot'].astype(np.float32)
+ if self.rgb:
+ pc_color = preprocessed_data['pc_color'].astype(np.float32)
+ center = preprocessed_data['center'].astype(np.float32)
+ scale = float(preprocessed_data['scale'])
+ joint_translations, joint_rotations = joints_normalize(joint_translations, joint_rotations, center, scale)
+ affordable_positions, _ = joints_normalize(affordable_positions, None, center, scale)
+ else:
+ start_time = time.time()
+ if self.denoise:
+ # pcd = o3d.geometry.PointCloud()
+ # pcd.points = o3d.utility.Vector3dVector(pc)
+ # _, index = pcd.remove_statistical_outlier(nb_neighbors=100, std_ratio=1.5)
+ # pc = pc[index]
+ # if self.rgb:
+ # pc_color = pc_color[index]
+ valid_mask = pc[:, 0] > 0.05
+ pc = pc[valid_mask]
+ if self.rgb:
+ pc_color = pc_color[valid_mask]
+ end_time = time.time()
+ if self.debug:
+ print(f"denoise: {end_time - start_time}")
+ print(f"{pc.shape = }")
+
+ start_time = time.time()
+ pc, center, scale = pc_normalize(pc, self.normalize)
+ joint_translations, joint_rotations = joints_normalize(joint_translations, joint_rotations, center, scale)
+ affordable_positions, _ = joints_normalize(affordable_positions, None, center, scale)
+ end_time = time.time()
+ if self.debug:
+ print(f"pc_normalize: {end_time - start_time}")
+
+ start_time = time.time()
+ indices = ME.utils.sparse_quantize(np.ascontiguousarray(pc), return_index=True, quantization_size=self.resolution)[1]
+ pc = np.ascontiguousarray(pc[indices].astype(np.float32)) # (N', 3)
+ if self.rgb:
+ pc_color = pc_color[indices] # (N', 3)
+ end_time = time.time()
+ if self.debug:
+ print(f"sparse_quantize: {end_time - start_time}")
+ print(f"{pc.shape = }")
+
+ start_time = time.time()
+ pc_normal = shot.estimate_normal(pc, self.resolution * self.receptive_field).reshape(-1, 3).astype(np.float32)
+ pc_normal[~np.isfinite(pc_normal)] = 0 # (N', 3)
+ end_time = time.time()
+ if self.debug:
+ print(f"estimate_normal: {end_time - start_time}")
+ print(f"{pc_normal.shape = }")
+
+ start_time = time.time()
+ pc_shot = shot.compute(pc, self.resolution * self.receptive_field, self.resolution * self.receptive_field).reshape(-1, 352).astype(np.float32)
+ pc_shot[~np.isfinite(pc_shot)] = 0 # (N', 352)
+ end_time = time.time()
+ if self.debug:
+ print(f"shot: {end_time - start_time}")
+ print(f"{pc_shot.shape = }")
+
+ start_time = time.time()
+ pc, indices = farthest_point_sample(pc, self.sample_points_num) # (N, 3)
+ pc_normal = pc_normal[indices] # (N, 3)
+ pc_shot = pc_shot[indices] # (N, 352)
+ if self.rgb:
+ pc_color = pc_color[indices] # (N, 3)
+ end_time = time.time()
+ if self.debug:
+ print(f"farthest_point_sample: {end_time - start_time}")
+ print(f"{pc.shape = }", f"{pc_normal.shape = }", f"{pc_shot.shape = }")
+
+ if not self.debug:
+ os.makedirs(os.path.dirname(preprocessed_path), exist_ok=True)
+ if self.rgb:
+ np.savez(preprocessed_path, pc=pc, pc_normal=pc_normal, pc_shot=pc_shot, pc_color=pc_color, center=center, scale=scale)
+ else:
+ np.savez(preprocessed_path, pc=pc, pc_normal=pc_normal, pc_shot=pc_shot, center=center, scale=scale)
+
+ if self.debug and self.vis:
+ # camera coordinate as (x, y, z) = (right, down, in)
+            # normals may point either inward or outward
+ visualize(pc, pc_color=pc_color if self.rgb else light_blue_color, pc_normal=pc_normal,
+ joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions,
+ joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color,
+ whether_frame=True, whether_bbox=True, window_name='after')
+
+ # sample point tuples
+ start_time = time.time()
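+        # each tuple is an ordered pair of distinct point indices (drawn without replacement from the
+        # precomputed permutations), padded with tuple_more_num extra random indices as context points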
+ point_idxs = random.sample(self.permutations, self.sample_tuples_num)
+ point_idxs = np.array(point_idxs, dtype=np.int64) # (N_t, 2)
+ point_idxs_more = np.random.randint(0, self.sample_points_num, size=(self.sample_tuples_num, self.tuple_more_num), dtype=np.int64) # (N_t, N_m)
+ point_idxs_all = np.concatenate([point_idxs, point_idxs_more], axis=-1) # (N_t, 2 + N_m)
+ end_time = time.time()
+ if self.debug:
+ print(f"sample_point_tuples: {end_time - start_time}")
+ print(f"{point_idxs_all.shape = }")
+
+ # generate targets
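+        # per-joint voting targets computed from the first two points of each tuple:
+        # targets_tr relates the pair to the joint position, targets_rot to the joint axis direction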
+ start_time = time.time()
+ targets_tr, targets_rot = [], []
+ for j in range(self.joint_num):
+ target_tr = generate_target_tr(pc, joint_translations[j], point_idxs_all[:, :2]) # (N_t, 2)
+ target_rot = generate_target_rot(pc, joint_rotations[j], point_idxs_all[:, :2]) # (N_t,)
+ targets_tr.append(target_tr)
+ targets_rot.append(target_rot)
+ targets_tr = np.stack(targets_tr, axis=0).astype(np.float32) # (J, N_t, 2)
+ targets_rot = np.stack(targets_rot, axis=0).astype(np.float32) # (J, N_t)
+ end_time = time.time()
+ if self.debug:
+ print(f"generate_targets: {end_time - start_time}")
+ print(f"{targets_tr.shape = }", f"{targets_rot.shape = }")
+
+ if self.is_train:
+ if self.rgb:
+ return pc, pc_normal, pc_shot, pc_color, \
+ targets_tr, targets_rot, \
+ point_idxs_all
+ else:
+ return pc, pc_normal, pc_shot, \
+ targets_tr, targets_rot, \
+ point_idxs_all
+ else:
+ if self.rgb:
+ return pc, pc_normal, pc_shot, pc_color, \
+ joint_translations, joint_rotations, affordable_positions, joint_types, \
+ point_idxs_all, str(self.fns[idx])
+ else:
+ return pc, pc_normal, pc_shot, \
+ joint_translations, joint_rotations, affordable_positions, joint_types, \
+ point_idxs_all, str(self.fns[idx])
+
+
+if __name__ == '__main__':
+ setup_seed(seed=42)
+ path = "/data2/junbo/RealArt-6/without_table/microwave"
+ instances = ['0_without_chaos', '1_without_chaos', '2_without_chaos', '3_without_chaos', '4_without_chaos']
+ joint_num = 1
+ resolution = 1e-2
+ receptive_field = 10
+ sample_points_num = 1024
+ sample_tuples_num = 100000
+ tuple_more_num = 3
+ normalize = 'none'
+ rgb = False
+ denoise = True
+ dataset = ArticulationDataset(path, instances, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ rgb, denoise, normalize, debug=True, vis=True, is_train=False)
+ batch_size = 2
+ num_workers = 0
+ dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+ for results in tqdm.tqdm(dataloader):
+ import pdb; pdb.set_trace()
diff --git a/datasets/rconfmask_afford_point_tuple_dataset.py b/datasets/rconfmask_afford_point_tuple_dataset.py
new file mode 100644
index 0000000..1c00c6d
--- /dev/null
+++ b/datasets/rconfmask_afford_point_tuple_dataset.py
@@ -0,0 +1,260 @@
+from typing import List
+import os
+import time
+import itertools
+from pathlib import Path
+import tqdm
+import random
+import numpy as np
+import torch
+import MinkowskiEngine as ME
+import open3d as o3d
+
+if __name__ == '__main__':
+ import sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from utilities.data_utils import transform_pc, transform_dir, pc_normalize, joints_normalize, farthest_point_sample, generate_target_tr, generate_target_rot, pc_noise
+from utilities.vis_utils import visualize_mask, visualize_confidence_voting
+from utilities.env_utils import setup_seed
+from src_shot.build import shot
+
+
+class ArticulationDataset(torch.utils.data.Dataset):
+ def __init__(self, path:str, categories:List[str], joint_num:int, resolution:float, receptive_field:int,
+ sample_points_num:int, sample_tuples_num:int, tuple_more_num:int,
+ noise:bool, distortion_rate:float, distortion_level:float, outlier_rate:float, outlier_level:float,
+ rgb:bool, denoise:bool, normalize:str, debug:bool, vis:bool, is_train:bool) -> None:
+ super().__init__()
+ self.path = path
+ self.categories = categories
+ self.joint_num = joint_num
+ self.resolution = resolution
+ self.receptive_field = receptive_field
+ self.sample_points_num = sample_points_num
+ self.sample_tuples_num = sample_tuples_num
+ self.tuple_more_num = tuple_more_num
+ self.noise = noise # NOTE: only used in testing
+ self.distortion_rate = distortion_rate
+ self.distortion_level = distortion_level
+ self.outlier_rate = outlier_rate
+ self.outlier_level = outlier_level
+ self.rgb = rgb
+ self.denoise = denoise
+ self.normalize = normalize
+ self.debug = debug
+ self.vis = vis
+ self.is_train = is_train
+ self.fns = sorted(list(itertools.chain(*[list(Path(path).glob('{}*/*.npz'.format(category))) for category in categories])))
+ self.permutations = list(itertools.permutations(range(self.sample_points_num), 2))
+ if debug:
+ print(f"{len(self.fns) =}", f"{self.fns[0] =}")
+
+ def __len__(self):
+ return len(self.fns)
+
+ def __getitem__(self, idx:int):
+ # load data
+ data = np.load(self.fns[idx])
+
+ c2w = data['extrinsic'].astype(np.float32) # (4, 4)
+ w2c = np.linalg.inv(c2w)
+ pc = data['pcd_world'].astype(np.float32) # (N'', 3)
+ pc = transform_pc(pc, w2c)
+ if self.noise:
+ pc = pc_noise(pc, self.distortion_rate, self.distortion_level, self.outlier_rate, self.outlier_level)
+ if self.debug:
+ print(f"{pc.shape = }")
+ print(f"{idx = }", f"{self.fns[idx] = }")
+
+ if self.rgb:
+ pc_color = data['pcd_color'].astype(np.float32) # (N'', 3)
+
+ instance_mask = data['instance_mask'].astype(np.int64) # (N'',)
+ function_mask = data['function_mask'].astype(np.int64) # (N'',)
+
+ joint_translations = data['joint_bases'].astype(np.float32) # (J, 3)
+ joint_translations = transform_pc(joint_translations, w2c)
+ joint_rotations = data['joint_directions'].astype(np.float32) # (J, 3)
+ joint_rotations = transform_dir(joint_rotations, w2c)
+ affordable_positions = data['affordable_positions'].astype(np.float32) # (J, 3)
+ affordable_positions = transform_pc(affordable_positions, w2c)
+ joint_num = joint_translations.shape[0]
+ assert self.joint_num == joint_num
+ if self.debug:
+ print(f"{joint_translations.shape = }", f"{joint_rotations.shape = }", f"{affordable_positions.shape = }")
+
+ if self.debug and self.vis:
+ # camera coordinate as (x, y, z) = (in, left, up)
+            # normals may point either inward or outward
+ visualize_mask(pc, instance_mask, function_mask, pc_normal=None,
+ joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions,
+ whether_frame=True, whether_bbox=True, window_name='before')
+
+ # preprocess
+ start_time = time.time()
+ if self.denoise:
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pc)
+ _, index = pcd.remove_statistical_outlier(nb_neighbors=100, std_ratio=1.5)
+ pc = pc[index]
+ if self.rgb:
+ pc_color = pc_color[index]
+ end_time = time.time()
+ if self.debug:
+ print(f"denoise: {end_time - start_time}")
+ print(f"{pc.shape = }")
+
+ start_time = time.time()
+ pc, center, scale = pc_normalize(pc, self.normalize)
+ joint_translations, joint_rotations = joints_normalize(joint_translations, joint_rotations, center, scale)
+ affordable_positions, _ = joints_normalize(affordable_positions, None, center, scale)
+ end_time = time.time()
+ if self.debug:
+ print(f"pc_normalize: {end_time - start_time}")
+
+ start_time = time.time()
+ indices = ME.utils.sparse_quantize(np.ascontiguousarray(pc), return_index=True, quantization_size=self.resolution)[1]
+ pc = np.ascontiguousarray(pc[indices].astype(np.float32)) # (N', 3)
+ instance_mask = instance_mask[indices] # (N',)
+ function_mask = function_mask[indices] # (N',)
+ if self.rgb:
+ pc_color = pc_color[indices] # (N', 3)
+ end_time = time.time()
+ if self.debug:
+ print(f"sparse_quantize: {end_time - start_time}")
+ print(f"{pc.shape = }")
+
+ start_time = time.time()
+ pc_normal = shot.estimate_normal(pc, self.resolution * self.receptive_field).reshape(-1, 3).astype(np.float32)
+ pc_normal[~np.isfinite(pc_normal)] = 0 # (N', 3)
+ end_time = time.time()
+ if self.debug:
+ print(f"estimate_normal: {end_time - start_time}")
+ print(f"{pc_normal.shape = }")
+
+ start_time = time.time()
+ pc_shot = shot.compute(pc, self.resolution * self.receptive_field, self.resolution * self.receptive_field).reshape(-1, 352).astype(np.float32)
+ pc_shot[~np.isfinite(pc_shot)] = 0 # (N', 352)
+ end_time = time.time()
+ if self.debug:
+ print(f"shot: {end_time - start_time}")
+ print(f"{pc_shot.shape = }")
+
+ start_time = time.time()
+ pc, indices = farthest_point_sample(pc, self.sample_points_num) # (N, 3)
+ pc_normal = pc_normal[indices] # (N, 3)
+ pc_shot = pc_shot[indices] # (N, 352)
+ if self.rgb:
+ pc_color = pc_color[indices] # (N, 3)
+ instance_mask = instance_mask[indices] # (N,)
+ function_mask = function_mask[indices] # (N,)
+ end_time = time.time()
+ if self.debug:
+ print(f"farthest_point_sample: {end_time - start_time}")
+ print(f"{pc.shape = }", f"{pc_normal.shape = }", f"{pc_shot.shape = }")
+
+ if self.debug and self.vis:
+ # camera coordinate as (x, y, z) = (in, left, up)
+            # normals may point either inward or outward
+ visualize_mask(pc, instance_mask, function_mask, pc_normal=pc_normal,
+ joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions,
+ whether_frame=True, whether_bbox=True, window_name='after')
+
+ # sample point tuples
+ start_time = time.time()
+ point_idxs = random.sample(self.permutations, self.sample_tuples_num)
+ point_idxs = np.array(point_idxs, dtype=np.int64) # (N_t, 2)
+ point_idxs_more = np.random.randint(0, self.sample_points_num, size=(self.sample_tuples_num, self.tuple_more_num), dtype=np.int64) # (N_t, N_m)
+ point_idxs_all = np.concatenate([point_idxs, point_idxs_more], axis=-1) # (N_t, 2 + N_m)
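+        # classify each point pair by the links its two points fall on:
+        #   base_case_mask      - both points on the base link (instance id 0)
+        #   part_base_case_mask - exactly one point on the base link, the other on a movable part
+        #   part_case_mask      - both points on movable parts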
+ base_mask = (instance_mask == 0)
+ base_idxs = np.where(base_mask)[0]
+ part_idxs = []
+ for j in range(joint_num):
+ part_mask = (instance_mask == j + 1)
+ part_idxs.append(np.where(part_mask)[0])
+ base_idxs_mask = np.isin(point_idxs, base_idxs)
+ base_case_mask = np.all(base_idxs_mask, axis=-1)
+ part_base_case_mask = np.logical_and(np.any(base_idxs_mask, axis=-1), np.logical_not(base_case_mask))
+ part_case_mask = np.logical_not(np.logical_or(base_case_mask, part_base_case_mask))
+ end_time = time.time()
+ if self.debug:
+ print(f"sample_point_tuples: {end_time - start_time}")
+ print(f"{point_idxs.shape = }, {point_idxs_all.shape = }")
+ if self.debug and self.vis:
+ visualize_confidence_voting(np.ones((self.sample_tuples_num,)), pc, point_idxs, whether_frame=True, whether_bbox=True, window_name='point tuples')
+
+ # generate targets
+ start_time = time.time()
+ targets_tr = np.zeros((joint_num, self.sample_tuples_num, 2), dtype=np.float32) # (J, N_t, 2)
+ targets_rot = np.zeros((joint_num, self.sample_tuples_num), dtype=np.float32) # (J, N_t)
+ targets_afford = np.zeros((joint_num, self.sample_tuples_num, 2), dtype=np.float32) # (J, N_t, 2)
+ targets_conf = np.zeros((joint_num, self.sample_tuples_num), dtype=np.float32) # (J, N_t)
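+        # the confidence target is 1 only for tuples that pair a base-link point with a point on part j;
+        # all other tuples keep a confidence target of 0 for that joint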
+ for j in range(joint_num):
+ this_part_mask = np.any(np.isin(point_idxs[part_base_case_mask], part_idxs[j]), axis=-1)
+ same_part_mask = np.all(np.isin(point_idxs[part_case_mask], part_idxs[j]), axis=-1)
+ merge_this_part_mask = part_base_case_mask.copy()
+ merge_this_part_mask[part_base_case_mask] = this_part_mask.copy()
+ merge_same_part_mask = part_case_mask.copy()
+ merge_same_part_mask[part_case_mask] = same_part_mask.copy()
+ targets_tr[j] = generate_target_tr(pc, joint_translations[j], point_idxs)
+ targets_rot[j] = generate_target_rot(pc, joint_rotations[j], point_idxs)
+ targets_afford[j] = generate_target_tr(pc, affordable_positions[j], point_idxs)
+ # targets_conf[j, merge_same_part_mask] = 0.51
+ targets_conf[j, merge_this_part_mask] = 1.0
+ end_time = time.time()
+ if self.debug:
+ print(f"generate_targets: {end_time - start_time}")
+ print(f"{targets_tr.shape = }", f"{targets_rot.shape = }", f"{targets_conf.shape = }", f"{targets_afford.shape = }")
+
+ if self.is_train:
+ if self.rgb:
+ return pc, pc_normal, pc_shot, pc_color, \
+ targets_tr, targets_rot, targets_afford, targets_conf, \
+ point_idxs_all
+ else:
+ return pc, pc_normal, pc_shot, \
+ targets_tr, targets_rot, targets_afford, targets_conf, \
+ point_idxs_all
+ else:
+            # during testing the targets would not normally be known; they are returned here only for ground-truth evaluation
+ if self.rgb:
+ return pc, pc_normal, pc_shot, pc_color, \
+ joint_translations, joint_rotations, affordable_positions, \
+ targets_tr, targets_rot, targets_afford, targets_conf, \
+ point_idxs_all
+ else:
+ return pc, pc_normal, pc_shot, \
+ joint_translations, joint_rotations, affordable_positions, \
+ targets_tr, targets_rot, targets_afford, targets_conf, \
+ point_idxs_all
+
+
+if __name__ == '__main__':
+ setup_seed(seed=42)
+ path = "/data2/junbo/sapien4/test"
+ categories = ['Microwave']
+ joint_num = 1
+ # resolution = 2.5e-2
+ resolution = 1e-2
+ receptive_field = 10
+ sample_points_num = 1024
+ sample_tuples_num = 100000
+ tuple_more_num = 3
+ # normalize = 'median'
+ normalize = 'none'
+ noise = True
+ distortion_rate = 0.1
+ distortion_level = 0.01
+ outlier_rate = 0.001
+ outlier_level = 1.0
+ rgb = True
+ denoise = False
+ dataset = ArticulationDataset(path, categories, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ noise, distortion_rate, distortion_level, outlier_rate, outlier_level,
+ rgb, denoise, normalize, debug=True, vis=False, is_train=True)
+ batch_size = 1
+ num_workers = 0
+ dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+ for results in tqdm.tqdm(dataloader):
+ import pdb; pdb.set_trace()
diff --git a/deps/.gitkeep b/deps/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/envs/camera.py b/envs/camera.py
new file mode 100644
index 0000000..a2c229a
--- /dev/null
+++ b/envs/camera.py
@@ -0,0 +1,209 @@
+"""
+Modified from https://github.com/warshallrho/VAT-Mart/blob/main/code/camera.py
+"""
+import numpy as np
+from sapien.core import Pose
+
+from .env import Env
+
+class Camera(object):
+
+ def __init__(self, env:Env, near=0.1, far=100.0, image_size=448, dist=5.0, \
+ phi=np.pi/5, theta=np.pi, fov=35.0, random_position=False, fixed_position=False, restrict_dir=False, real_data=False):
+ builder = env.scene.create_actor_builder()
+ camera_mount_actor = builder.build(is_kinematic=True)
+ self.env = env
+
+ # set camera intrinsics
+ if isinstance(image_size, int):
+ width = image_size
+ height = image_size
+ else:
+ width = image_size[0]
+ height = image_size[1]
+ if isinstance(fov, float):
+ fovx = 0
+ fovy = fov
+ else:
+ fovx = fov[0]
+ fovy = fov[1]
+ self.camera = env.scene.add_mounted_camera('camera', camera_mount_actor, Pose(), \
+ width, height, np.deg2rad(fovx), np.deg2rad(fovy), near, far)
+
+ # set camera extrinsics
+ if random_position:
+ if restrict_dir:
+ theta = (0.25 + np.random.random() * 0.5) * np.pi*2
+ phi = (np.random.random()+1) * np.pi/6
+ else:
+ theta = np.random.random() * np.pi*2
+ phi = (np.random.random()+1) * np.pi/6
+ if fixed_position:
+ theta = np.pi
+ phi = np.pi/10
+ if random_position and real_data:
+ theta = np.random.random() * np.pi*2
+ phi = np.random.random() * np.pi * 2
+
+ pos = np.array([dist*np.cos(phi)*np.cos(theta), \
+ dist*np.cos(phi)*np.sin(theta), \
+ dist*np.sin(phi)])
+ forward = -pos / np.linalg.norm(pos)
+ left = np.cross([0, 0, 1], forward)
+ left = left / np.linalg.norm(left)
+ up = np.cross(forward, left)
+ mat44 = np.eye(4)
+ mat44[:3, :3] = np.vstack([forward, left, up]).T
+ mat44[:3, 3] = pos # mat44 is cam2world
+ mat44[0, 3] += env.object_position_offset
+ self.mat44 = mat44
+ camera_mount_actor.set_pose(Pose.from_transformation_matrix(mat44))
+
+ # log parameters
+ self.near = near
+ self.far = far
+ self.fov = [fovx, fovy]
+ self.dist = dist
+ self.theta = theta
+ self.phi = phi
+ self.pos = pos
+
+ self.camera_mount_actor = camera_mount_actor
+
+ def change_pose(self,
+ phi=np.pi/5, theta=np.pi, dist=5.0, random_position=False, restrict_dir=False):
+ # set camera extrinsics
+ if random_position:
+ if restrict_dir:
+ theta = (0.25 + np.random.random() * 0.5) * np.pi*2
+ phi = (np.random.random()+1) * np.pi/6
+ else:
+ theta = np.random.random() * np.pi*2
+ phi = (np.random.random()+1) * np.pi/6
+ pos = np.array([dist*np.cos(phi)*np.cos(theta), \
+ dist*np.cos(phi)*np.sin(theta), \
+ dist*np.sin(phi)])
+ forward = -pos / np.linalg.norm(pos)
+ left = np.cross([0, 0, 1], forward)
+ left = left / np.linalg.norm(left)
+ up = np.cross(forward, left)
+ mat44 = np.eye(4)
+ mat44[:3, :3] = np.vstack([forward, left, up]).T
+ mat44[:3, 3] = pos # mat44 is cam2world
+ mat44[0, 3] += self.env.object_position_offset
+ self.mat44 = mat44
+ self.camera_mount_actor.set_pose(Pose.from_transformation_matrix(mat44))
+
+ self.dist = dist
+ self.theta = theta
+ self.phi = phi
+ self.pos = pos
+
+ def change_fov(self, fov):
+ if isinstance(fov, float):
+ fovx = 0
+ fovy = fov
+ else:
+ fovx = fov[0]
+ fovy = fov[1]
+ self.camera.set_mode_perspective(np.deg2rad(fovy))
+ self.fov = [fovx, fovy]
+
+ def get_observation(self):
+ self.camera.take_picture()
+ rgba = self.camera.get_color_rgba()
+ rgba = (rgba * 255).clip(0, 255).astype(np.float32) / 255
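+        # composite the rendered image over a white background using the alpha channel as mask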
+ white = np.ones((rgba.shape[0], rgba.shape[1], 3), dtype=np.float32)
+ mask = np.tile(rgba[:, :, 3:4], [1, 1, 3])
+ rgb = rgba[:, :, :3] * mask + white * (1 - mask)
+ depth = self.camera.get_depth().astype(np.float32)
+ return rgb, depth
+
+ def compute_camera_XYZA(self, depth):
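+        # unproject pixels with valid depth (depth buffer < 1) into camera-frame points:
+        # z linearizes the depth-buffer values to metric depth, and the permutation matrix converts
+        # the (x right, y down, z forward) camera frame to SAPIEN's (x forward, y left, z up) frame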
+ camera_matrix = self.camera.get_camera_matrix()[:3, :3]
+ y, x = np.where(depth < 1)
+ z = self.near * self.far / (self.far + depth * (self.near - self.far))
+ permutation = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
+ points = (permutation @ np.dot(np.linalg.inv(camera_matrix), \
+ np.stack([x, y, np.ones_like(x)] * z[y, x], 0))).T
+ return y, x, points
+
+ @staticmethod
+ def compute_XYZA_matrix(id1, id2, pts, size1, size2):
+ out = np.zeros((size1, size2, 4), dtype=np.float32)
+ out[id1, id2, :3] = pts
+ out[id1, id2, 3] = 1
+ return out
+
+ def get_normal_map(self):
+ nor = self.camera.get_normal_rgba()
+ # convert from PartNet-space (x-right, y-up, z-backward) to SAPIEN-space (x-front, y-left, z-up)
+ new_nor = np.array(nor, dtype=np.float32)
+ new_nor[:, :, 0] = -nor[:, :, 2]
+ new_nor[:, :, 1] = -nor[:, :, 0]
+ new_nor[:, :, 2] = nor[:, :, 1]
+ return new_nor
+
+ def get_movable_link_mask(self, link_ids):
+ link_seg = self.camera.get_segmentation()
+ link_mask = np.zeros((link_seg.shape[0], link_seg.shape[1])).astype(np.uint8)
+ for idx, lid in enumerate(link_ids):
+ cur_link_pixels = int(np.sum(link_seg==lid))
+ if cur_link_pixels > 0:
+ link_mask[link_seg == lid] = idx+1
+ return link_mask
+
+ def get_handle_mask(self):
+ # read part seg partid2renderids
+ partid2renderids = dict()
+ for k in self.env.scene.render_id_to_visual_name:
+ if self.env.scene.render_id_to_visual_name[k].split('-')[0] == 'handle':
+ part_id = int(self.env.scene.render_id_to_visual_name[k].split('-')[-1])
+ if part_id not in partid2renderids:
+ partid2renderids[part_id] = []
+ partid2renderids[part_id].append(k)
+ # generate 0/1 handle mask
+ part_seg = self.camera.get_obj_segmentation()
+ handle_mask = np.zeros((part_seg.shape[0], part_seg.shape[1])).astype(np.uint8)
+ for partid in partid2renderids:
+ cur_part_mask = np.isin(part_seg, partid2renderids[partid])
+ cur_part_mask_pixels = int(np.sum(cur_part_mask))
+ if cur_part_mask_pixels > 0:
+ handle_mask[cur_part_mask] = 1
+ return handle_mask
+
+ def get_object_mask(self):
+ rgba = self.camera.get_albedo_rgba()
+ return rgba[:, :, 3] > 0.5
+
+ # return camera parameters
+ def get_metadata(self):
+ return {
+ 'pose': self.camera.get_pose(),
+ 'near': self.camera.get_near(),
+ 'far': self.camera.get_far(),
+ 'width': self.camera.get_width(),
+ 'height': self.camera.get_height(),
+ 'fov': self.camera.get_fovy(),
+ 'camera_matrix': self.camera.get_camera_matrix(),
+ 'projection_matrix': self.camera.get_projection_matrix(),
+ 'model_matrix': self.camera.get_model_matrix(),
+ 'mat44': self.mat44,
+ }
+
+ # return camera parameters
+ def get_metadata_json(self):
+ return {
+ 'dist': self.dist,
+ 'theta': self.theta,
+ 'phi': self.phi,
+ 'near': self.camera.get_near(),
+ 'far': self.camera.get_far(),
+ 'width': self.camera.get_width(),
+ 'height': self.camera.get_height(),
+ 'fov': self.camera.get_fovy(),
+ 'camera_matrix': self.camera.get_camera_matrix().tolist(),
+ 'projection_matrix': self.camera.get_projection_matrix().tolist(),
+ 'model_matrix': self.camera.get_model_matrix().tolist(),
+ 'mat44': self.mat44.tolist(),
+ }
diff --git a/envs/env.py b/envs/env.py
new file mode 100644
index 0000000..d71fbd4
--- /dev/null
+++ b/envs/env.py
@@ -0,0 +1,628 @@
+"""
+Modified from https://github.com/warshallrho/VAT-Mart/blob/main/code/env.py
+"""
+
+from __future__ import division
+import sapien.core as sapien
+from sapien.core import Pose, SceneConfig, OptifuserConfig, ArticulationJointType
+import numpy as np
+import trimesh
+
+
+def process_angle_limit(x):
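+    # clamp unbounded (infinite) joint limits to a finite range of [-10, 10]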
+ if np.isneginf(x):
+ x = -10
+ if np.isinf(x):
+ x = 10
+ return x
+
+def get_random_number(l, r):
+ return np.random.rand() * (r - l) + l
+
+
+class ContactError(Exception):
+ pass
+
+
+class SVDError(Exception):
+ pass
+
+
+class Env(object):
+
+ def __init__(self, flog=None, show_gui=True, render_rate=20, timestep=1/500, \
+ object_position_offset=0.0, succ_ratio=0.1):
+ self.current_step = 0
+
+ self.flog = flog
+ self.show_gui = show_gui
+ self.render_rate = render_rate
+ self.timestep = timestep
+ self.succ_ratio = succ_ratio
+ self.object_position_offset = object_position_offset
+
+ # engine and renderer
+ self.engine = sapien.Engine(0, 0.001, 0.005)
+
+ render_config = OptifuserConfig()
+ render_config.shadow_map_size = 8192
+ render_config.shadow_frustum_size = 10
+ render_config.use_shadow = False
+ render_config.use_ao = True
+
+ self.renderer = sapien.OptifuserRenderer(config=render_config)
+ self.renderer.enable_global_axes(False)
+
+ self.engine.set_renderer(self.renderer)
+
+ # GUI
+ self.window = False
+ if show_gui:
+ self.renderer_controller = sapien.OptifuserController(self.renderer)
+ self.renderer_controller.set_camera_position(-3.0+object_position_offset, 1.0, 3.0)
+ self.renderer_controller.set_camera_rotation(-0.4, -0.8)
+
+ # scene
+ scene_config = SceneConfig()
+ scene_config.gravity = [0, 0, 0]
+ scene_config.solver_iterations = 20
+ scene_config.enable_pcm = False
+ scene_config.sleep_threshold = 0.0
+
+ self.scene = self.engine.create_scene(config=scene_config)
+ if show_gui:
+ self.renderer_controller.set_current_scene(self.scene)
+
+ self.scene.set_timestep(timestep)
+
+ # add lights
+ self.scene.set_ambient_light([0.5, 0.5, 0.5])
+ self.scene.set_shadow_light([0, 1, -1], [0.5, 0.5, 0.5])
+ self.scene.add_point_light([1+object_position_offset, 2, 2], [1, 1, 1])
+ self.scene.add_point_light([1+object_position_offset, -2, 2], [1, 1, 1])
+ self.scene.add_point_light([-1+object_position_offset, 0, 1], [1, 1, 1])
+
+ # default Nones
+ self.object = None
+ self.object_target_joint = None
+
+ # check contact
+ self.check_contact = False
+ self.contact_error = False
+
+ # visual objects
+ self.visual_builder = self.scene.create_actor_builder()
+ self.visual_objects = dict()
+
+ def set_controller_camera_pose(self, x, y, z, yaw, pitch):
+ self.renderer_controller.set_camera_position(x, y, z)
+ self.renderer_controller.set_camera_rotation(yaw, pitch)
+ self.renderer_controller.render()
+
+ def load_object(self, urdf, material, state='closed', target_part_id=-1, target_part_idx=-1, scale=1.0):
+        # NOTE: setting target_part_idx only sets the other joints to nearly closed; it does not track the target joint
+ loader = self.scene.create_urdf_loader()
+ loader.scale = scale
+ self.object = loader.load(urdf, {"material": material})
+ pose = Pose([self.object_position_offset, 0, 0], [1, 0, 0, 0])
+ self.object.set_root_pose(pose)
+
+ # compute link actor information
+ self.all_link_ids = [l.get_id() for l in self.object.get_links()]
+ self.all_link_names = [l.get_name() for l in self.object.get_links()]
+ self.movable_link_ids = []
+ self.movable_joint_idxs = []
+ self.movable_link_joint_types = []
+ self.movable_link_joint_names = []
+
+ for j_idx, j in enumerate(self.object.get_joints()):
+ if j.get_dof() == 1:
+ if j.type == ArticulationJointType.REVOLUTE:
+ self.movable_link_joint_types.append(0)
+ if j.type == ArticulationJointType.PRISMATIC:
+ self.movable_link_joint_types.append(1)
+ self.movable_link_joint_names.append(j.get_name())
+
+ self.movable_link_ids.append(j.get_child_link().get_id())
+ self.movable_joint_idxs.append(j_idx)
+ if self.flog is not None:
+ self.flog.write('All Actor Link IDs: %s\n' % str(self.all_link_ids))
+ self.flog.write('All Movable Actor Link IDs: %s\n' % str(self.movable_link_ids))
+
+ # set joint property
+ for joint in self.object.get_joints():
+ joint.set_drive_property(stiffness=0, damping=10)
+
+ # set initial qpos
+ joint_angles = []
+ joint_abs_angles = []
+ self.joint_angles_lower = []
+ self.joint_angles_upper = []
+ target_part_joint_idx = -1
+ joint_idx = 0
+ for j in self.object.get_joints():
+ if j.get_dof() == 1:
+ if j.get_child_link().get_id() == target_part_id:
+ target_part_joint_idx = len(joint_angles)
+ l = process_angle_limit(j.get_limits()[0, 0])
+ self.joint_angles_lower.append(float(l))
+ r = process_angle_limit(j.get_limits()[0, 1])
+ self.joint_angles_upper.append(float(r))
+ if state == 'closed':
+ joint_angles.append(float(l))
+ elif state == 'open':
+ joint_angles.append(float(r))
+ elif state == 'random-middle':
+ joint_angles.append(float(get_random_number(l, r)))
+ elif state == 'random-closed-middle':
+ if np.random.random() < 0.5:
+ joint_angles.append(float(get_random_number(l, r)))
+ else:
+ joint_angles.append(float(l))
+ elif state == 'random-middle-middle':
+ if joint_idx == target_part_idx:
+ joint_angles.append(float(get_random_number(l + 0.15*(r-l), r - 0.15*(r-l))))
+ else:
+ joint_angles.append(float(get_random_number(l, l + 0.0*(r-l))))
+ else:
+ raise ValueError('ERROR: object init state %s unknown!' % state)
+ joint_abs_angles.append((joint_angles[-1]-l)/(r-l))
+ joint_idx += 1
+
+ self.object.set_qpos(joint_angles)
+ if target_part_id >= 0:
+ return joint_angles, target_part_joint_idx, joint_abs_angles
+ return joint_angles, joint_abs_angles
+
+ def load_real_object(self, urdf, material, joint_angles=None):
+ loader = self.scene.create_urdf_loader()
+ self.object = loader.load(urdf, {"material": material})
+ pose = Pose([self.object_position_offset, 0, 0], [1, 0, 0, 0])
+ self.object.set_root_pose(pose)
+
+ # compute link actor information
+ self.all_link_ids = [l.get_id() for l in self.object.get_links()]
+ self.movable_link_ids = []
+ for j in self.object.get_joints():
+ if j.get_dof() == 1:
+ self.movable_link_ids.append(j.get_child_link().get_id())
+ if self.flog is not None:
+ self.flog.write('All Actor Link IDs: %s\n' % str(self.all_link_ids))
+ self.flog.write('All Movable Actor Link IDs: %s\n' % str(self.movable_link_ids))
+
+ # set joint property
+ for joint in self.object.get_joints():
+ joint.set_drive_property(stiffness=0, damping=10)
+
+ if joint_angles is not None:
+ self.object.set_qpos(joint_angles)
+
+ return None
+
+ def update_and_set_joint_angles_all(self, state='closed'):
+ joint_angles = []
+ for j in self.object.get_joints():
+ if j.get_dof() == 1:
+ l = process_angle_limit(j.get_limits()[0, 0])
+ self.joint_angles_lower.append(float(l))
+ r = process_angle_limit(j.get_limits()[0, 1])
+ self.joint_angles_upper.append(float(r))
+ if state == 'closed':
+ joint_angles.append(float(l))
+ elif state == 'open':
+ joint_angles.append(float(r))
+ elif state == 'random-middle':
+ joint_angles.append(float(get_random_number(l, r)))
+ elif state == 'random-closed-middle':
+ if np.random.random() < 0.5:
+ joint_angles.append(float(get_random_number(l, r)))
+ else:
+ joint_angles.append(float(l))
+ else:
+ raise ValueError('ERROR: object init state %s unknown!' % state)
+ self.object.set_qpos(joint_angles)
+ return joint_angles
+
+ def get_target_part_axes_new(self, target_part_id):
+ joint_axes = None
+ for j in self.object.get_joints():
+ if j.get_dof() == 1:
+ if j.get_child_link().get_id() == target_part_id:
+ pos = j.get_global_pose()
+ mat = pos.to_transformation_matrix()
+ joint_axes = [float(-mat[0, 0]), float(-mat[1, 0]), float(mat[2, 0])]
+ if joint_axes is None:
+ raise ValueError('joint axes error!')
+
+ return joint_axes
+
+ def get_target_part_axes_dir_new(self, target_part_id):
+ joint_axes = self.get_target_part_axes_new(target_part_id=target_part_id)
+ axes_dir = -1
+ for idx_axes_dim in range(3):
+ if abs(joint_axes[idx_axes_dim]) > 0.1:
+ axes_dir = idx_axes_dim
+ return axes_dir
+
+ def get_target_part_origins_new(self, target_part_id):
+ joint_origins = None
+ for j in self.object.get_joints():
+ if j.get_dof() == 1:
+ if j.get_child_link().get_id() == target_part_id:
+ pos = j.get_global_pose()
+ joint_origins = pos.p.tolist()
+ if joint_origins is None:
+ raise ValueError('joint origins error!')
+
+ return joint_origins
+
+ def update_joint_angle(self, joint_angles, target_part_joint_idx, state, task_lower, push=True, pull=False, drawer=False):
+ if push:
+ if drawer:
+ l = max(self.joint_angles_lower[target_part_joint_idx], self.joint_angles_lower[target_part_joint_idx] + task_lower)
+ r = self.joint_angles_upper[target_part_joint_idx]
+ else:
+ l = max(self.joint_angles_lower[target_part_joint_idx], self.joint_angles_lower[target_part_joint_idx] + task_lower * np.pi / 180)
+ r = self.joint_angles_upper[target_part_joint_idx]
+ if pull:
+ if drawer:
+ l = self.joint_angles_lower[target_part_joint_idx]
+ r = self.joint_angles_upper[target_part_joint_idx] - task_lower
+ else:
+ l = self.joint_angles_lower[target_part_joint_idx]
+ r = self.joint_angles_upper[target_part_joint_idx] - task_lower * np.pi / 180
+ if state == 'closed':
+ joint_angles[target_part_joint_idx] = (float(l))
+ elif state == 'open':
+ joint_angles[target_part_joint_idx] = float(r)
+ elif state == 'random-middle':
+ joint_angles[target_part_joint_idx] = float(get_random_number(l, r))
+ elif state == 'random-middle-open':
+ joint_angles[target_part_joint_idx] = float(get_random_number(r * 0.8, r))
+ elif state == 'random-closed-middle':
+ if np.random.random() < 0.5:
+ joint_angles[target_part_joint_idx] = float(get_random_number(l, r))
+ else:
+ joint_angles[target_part_joint_idx] = float(l)
+ else:
+ raise ValueError('ERROR: object init state %s unknown!' % state)
+ return joint_angles
+
+ def set_object_joint_angles(self, joint_angles):
+ self.object.set_qpos(joint_angles)
+
+ def set_target_object_part_actor_id(self, actor_id):
+ if self.flog is not None:
+ self.flog.write('Set Target Object Part Actor ID: %d\n' % actor_id)
+ self.target_object_part_actor_id = actor_id
+ self.non_target_object_part_actor_id = list(set(self.all_link_ids) - set([actor_id]))
+
+ # get the link handler
+ for j in self.object.get_joints():
+ if j.get_dof() == 1:
+ if j.get_child_link().get_id() == actor_id:
+ self.target_object_part_actor_link = j.get_child_link()
+
+        # monitor the target joint
+ idx = 0
+ for j in self.object.get_joints():
+ if j.get_dof() == 1:
+ if j.get_child_link().get_id() == actor_id:
+ self.target_object_part_joint_id = idx
+ self.target_object_part_joint_type = j.type
+ idx += 1
+
+ def get_object_qpos(self):
+ return self.object.get_qpos()
+
+ def get_target_part_qpos(self):
+ qpos = self.object.get_qpos()
+ return float(qpos[self.target_object_part_joint_id])
+
+ def get_target_part_state(self):
+ qpos = self.object.get_qpos()
+ return float(qpos[self.target_object_part_joint_id]) - self.joint_angles_lower[self.target_object_part_joint_id]
+
+ def get_target_part_pose(self):
+ return self.target_object_part_actor_link.get_pose()
+
+ def start_checking_contact(self, robot_hand_actor_id, robot_gripper_actor_ids, strict):
+ self.check_contact = True
+ self.check_contact_strict = strict
+ self.first_timestep_check_contact = True
+ self.robot_hand_actor_id = robot_hand_actor_id
+ self.robot_gripper_actor_ids = robot_gripper_actor_ids
+ self.contact_error = False
+
+ def end_checking_contact(self, robot_hand_actor_id, robot_gripper_actor_ids, strict):
+ self.check_contact = False
+ self.check_contact_strict = strict
+ self.first_timestep_check_contact = False
+ self.robot_hand_actor_id = robot_hand_actor_id
+ self.robot_gripper_actor_ids = robot_gripper_actor_ids
+
+ def get_material(self, static_friction, dynamic_friction, restitution):
+ return self.engine.create_physical_material(static_friction, dynamic_friction, restitution)
+
+ def render(self):
+ if self.show_gui and (not self.window):
+ self.window = True
+ self.renderer_controller.show_window()
+ self.scene.update_render()
+ if self.show_gui and (self.current_step % self.render_rate == 0):
+ self.renderer_controller.render()
+
+ def step(self):
+ self.current_step += 1
+ self.scene.step()
+ if self.check_contact:
+ if not self.check_contact_is_valid():
+ raise ContactError()
+
+ # check the first contact: only gripper links can touch the target object part link
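+ # Returns True unless a forbidden contact is detected, in which case self.contact_error is set and False
+ # is returned. Forbidden contacts: any robot contact on the first checked timestep, and, in strict mode,
+ # contacts by the hand link or with non-target parts. Once the gripper validly touches the target part,
+ # self.check_contact is cleared so that checking stops.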
+ def check_contact_is_valid(self):
+ self.contacts = self.scene.get_contacts()
+ contact, valid = False, False
+ for c in self.contacts:
+ aid1 = c.actor1.get_id()
+ aid2 = c.actor2.get_id()
+ has_impulse = False
+ for p in c.points:
+ if abs(p.impulse @ p.impulse) > 1e-4:
+ has_impulse = True
+ break
+ if has_impulse:
+ if (aid1 in self.robot_gripper_actor_ids and aid2 == self.target_object_part_actor_id) or \
+ (aid2 in self.robot_gripper_actor_ids and aid1 == self.target_object_part_actor_id):
+ contact, valid = True, True
+ if (aid1 in self.robot_gripper_actor_ids and aid2 in self.non_target_object_part_actor_id) or \
+ (aid2 in self.robot_gripper_actor_ids and aid1 in self.non_target_object_part_actor_id):
+ if self.check_contact_strict:
+ self.contact_error = True
+ return False
+ else:
+ contact, valid = True, True
+ if (aid1 == self.robot_hand_actor_id or aid2 == self.robot_hand_actor_id):
+ if self.check_contact_strict:
+ self.contact_error = True
+ return False
+ else:
+ contact, valid = True, True
+ # starting pose should have no collision at all
+ if (aid1 in self.robot_gripper_actor_ids or aid1 == self.robot_hand_actor_id or \
+ aid2 in self.robot_gripper_actor_ids or aid2 == self.robot_hand_actor_id) and self.first_timestep_check_contact:
+ self.contact_error = True
+ return False
+
+ self.first_timestep_check_contact = False
+ if contact and valid:
+ self.check_contact = False
+ return True
+
+ def check_contact_right(self):
+ contacts = self.scene.get_contacts()
+ if len(contacts) < len(self.robot_gripper_actor_ids):
+ # print("no enough contacts")
+ return False
+ # for c in contacts:
+ # aid1 = c.actor1.get_id()
+ # aid2 = c.actor2.get_id()
+ # if (aid1 in self.robot_gripper_actor_ids and aid2 == self.target_object_part_actor_id) or \
+ # (aid2 in self.robot_gripper_actor_ids and aid1 == self.target_object_part_actor_id):
+ # print("right")
+ # pass
+ # elif (aid1 in self.robot_gripper_actor_ids and aid2 in self.non_target_object_part_actor_id) or \
+ # (aid2 in self.robot_gripper_actor_ids and aid1 in self.non_target_object_part_actor_id):
+ # print("unright")
+ # right = False
+ # elif (aid1 == self.robot_hand_actor_id or aid2 == self.robot_hand_actor_id):
+ # print("hand contact")
+ # right = False
+ # elif (aid1 in self.robot_gripper_actor_ids and aid2 in self.robot_gripper_actor_ids):
+ # print("also non-successful grasp")
+ # right = False
+ # else:
+ # print(c.actor1.get_name(), c.actor2.get_name())
+ right_aids = []
+ for c in contacts:
+ aid1 = c.actor1.get_id()
+ aid2 = c.actor2.get_id()
+ if (aid1 in self.robot_gripper_actor_ids and aid2 == self.target_object_part_actor_id):
+ right_aids.append(aid1)
+ elif (aid2 in self.robot_gripper_actor_ids and aid1 == self.target_object_part_actor_id):
+ right_aids.append(aid2)
+ else:
+ pass
+ right = (set(right_aids) == set(self.robot_gripper_actor_ids))
+ return right
+
+ def close_render(self):
+ if self.window:
+ self.renderer_controller.hide_window()
+ self.window = False
+
+ def wait_to_start(self):
+ print('press q to start\n')
+ while not self.renderer_controller.should_quit:
+ self.scene.update_render()
+ if self.show_gui:
+ self.renderer_controller.render()
+
+ def close(self):
+ if self.show_gui:
+ self.renderer_controller.set_current_scene(None)
+ self.scene = None
+
+ def get_global_mesh(self, obj):
+ final_vs = []
+ final_fs = []
+ vid = 0
+ for l in obj.get_links():
+ vs = []
+ for s in l.get_collision_shapes():
+ v = np.array(s.convex_mesh_geometry.vertices, dtype=np.float32)
+ f = np.array(s.convex_mesh_geometry.indices, dtype=np.uint32).reshape(-1, 3)
+ vscale = s.convex_mesh_geometry.scale
+ v[:, 0] *= vscale[0]
+ v[:, 1] *= vscale[1]
+ v[:, 2] *= vscale[2]
+ ones = np.ones((v.shape[0], 1), dtype=np.float32)
+ v_ones = np.concatenate([v, ones], axis=1)
+ transmat = s.pose.to_transformation_matrix()
+ v = (v_ones @ transmat.T)[:, :3]
+ vs.append(v)
+ final_fs.append(f + vid)
+ vid += v.shape[0]
+ if len(vs) > 0:
+ vs = np.concatenate(vs, axis=0)
+ ones = np.ones((vs.shape[0], 1), dtype=np.float32)
+ vs_ones = np.concatenate([vs, ones], axis=1)
+ transmat = l.get_pose().to_transformation_matrix()
+ vs = (vs_ones @ transmat.T)[:, :3]
+ final_vs.append(vs)
+ final_vs = np.concatenate(final_vs, axis=0)
+ final_fs = np.concatenate(final_fs, axis=0)
+ return final_vs, final_fs
+
+ def get_part_mesh(self, obj, part_id):
+ final_vs = []
+ final_fs = []
+ vid = 0
+ for l in obj.get_links():
+ l_id = l.get_id()
+ if l_id != part_id:
+ continue
+ vs = []
+ for s in l.get_collision_shapes():
+ v = np.array(s.convex_mesh_geometry.vertices, dtype=np.float32)
+ f = np.array(s.convex_mesh_geometry.indices, dtype=np.uint32).reshape(-1, 3)
+ vscale = s.convex_mesh_geometry.scale
+ v[:, 0] *= vscale[0]
+ v[:, 1] *= vscale[1]
+ v[:, 2] *= vscale[2]
+ ones = np.ones((v.shape[0], 1), dtype=np.float32)
+ v_ones = np.concatenate([v, ones], axis=1)
+ transmat = s.pose.to_transformation_matrix()
+ v = (v_ones @ transmat.T)[:, :3]
+ vs.append(v)
+ final_fs.append(f + vid)
+ vid += v.shape[0]
+ if len(vs) > 0:
+ vs = np.concatenate(vs, axis=0)
+ ones = np.ones((vs.shape[0], 1), dtype=np.float32)
+ vs_ones = np.concatenate([vs, ones], axis=1)
+ transmat = l.get_pose().to_transformation_matrix()
+ vs = (vs_ones @ transmat.T)[:, :3]
+ final_vs.append(vs)
+ final_vs = np.concatenate(final_vs, axis=0)
+ final_fs = np.concatenate(final_fs, axis=0)
+ return final_vs, final_fs
+
+ def sample_pc(self, v, f, n_points=4096):
+ mesh = trimesh.Trimesh(vertices=v, faces=f)
+ points, __ = trimesh.sample.sample_surface(mesh=mesh, count=n_points)
+ return points
+
+ def check_drawer(self):
+ for j in self.object.get_joints():
+ if j.get_dof() == 1 and (j.type == ArticulationJointType.PRISMATIC):
+ return True
+ return False
+
+ def add_point_visual(self, point:np.ndarray, color=[1, 0, 0], radius=0.04, name='point_visual'):
+ self.visual_builder.add_sphere_visual(pose=Pose(p=point), radius=radius, color=color, name=name)
+ point_visual = self.visual_builder.build_static(name=name)
+ assert name not in self.visual_objects.keys()
+ self.visual_objects[name] = point_visual
+
+ def add_line_visual(self, point1:np.ndarray, point2:np.ndarray, color=[1, 0, 0], width=0.03, name='line_visual'):
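+ # Render the line as a thin box centered at the segment midpoint, with its local z-axis aligned to the
+ # point1 -> point2 direction; the cross products below build the remaining two orthonormal axes.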
+ direction = point2 - point1
+ direction = direction / np.linalg.norm(direction)
+ rotation = np.zeros((3, 3))
+ temp2 = np.cross(direction, np.array([1., 0., 0.]))
+ if np.linalg.norm(temp2) < 1e-6:
+ temp1 = np.cross(np.array([0., 1., 0.]), direction)
+ temp1 /= np.linalg.norm(temp1)
+ temp2 = np.cross(direction, temp1)
+ temp2 /= np.linalg.norm(temp2)
+ else:
+ temp2 /= np.linalg.norm(temp2)
+ temp1 = np.cross(temp2, direction)
+ temp1 /= np.linalg.norm(temp1)
+ rotation[:, 0] = temp1
+ rotation[:, 1] = temp2
+ rotation[:, 2] = direction
+ pose_transformation = np.eye(4)
+ pose_transformation[:3, 3] = (point1+point2)/2
+ pose_transformation[:3, :3] = rotation
+ pose = Pose().from_transformation_matrix(pose_transformation)
+ size = [width/2, width/2, np.linalg.norm(point1 - point2)/2]
+ self.visual_builder.add_box_visual(pose=pose, size=size, color=color, name=name)
+ line_visual = self.visual_builder.build_static(name=name)
+ assert name not in self.visual_objects.keys()
+ self.visual_objects[name] = line_visual
+
+ def add_grasp_visual(self, grasp_width, grasp_depth, grasp_translation, grasp_rotation, affordance=0.5, name='grasp_visual'):
+ finger_width = 0.004
+ tail_length = 0.04
+ depth_base = 0.02
+ gg_width = grasp_width
+ gg_depth = grasp_depth
+ gg_translation = grasp_translation
+ gg_rotation = grasp_rotation
+
+ left = np.zeros((2, 3))
+ left[0] = np.array([-depth_base - finger_width, -gg_width / 2, 0])
+ left[1] = np.array([gg_depth, -gg_width / 2, 0])
+
+ right = np.zeros((2, 3))
+ right[0] = np.array([-depth_base - finger_width, gg_width / 2, 0])
+ right[1] = np.array([gg_depth, gg_width / 2, 0])
+
+ bottom = np.zeros((2, 3))
+ bottom[0] = np.array([-depth_base - finger_width, -gg_width / 2, 0])
+ bottom[1] = np.array([-depth_base - finger_width, gg_width / 2, 0])
+
+ tail = np.zeros((2, 3))
+ tail[0] = np.array([-(tail_length + finger_width + depth_base), 0, 0])
+ tail[1] = np.array([-(finger_width + depth_base), 0, 0])
+
+ vertices = np.vstack([left, right, bottom, tail])
+ vertices = np.dot(gg_rotation, vertices.T).T + gg_translation
+
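+ # color encodes the affordance score: red -> yellow for values in [0, 0.5), yellow -> green for [0.5, 1),
+ # and blue for exactly 1.0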
+ if affordance < 0.5:
+ color = [1, 2*affordance, 0]
+ elif affordance == 1.0:
+ color = [0, 0, 1]
+ else:
+ color = [-2*affordance+2, 1, 0]
+ self.add_line_visual(vertices[0], vertices[1], color, width=0.005, name=name+'_left')
+ self.add_line_visual(vertices[2], vertices[3], color, width=0.005, name=name+'_right')
+ self.add_line_visual(vertices[4], vertices[5], color, width=0.005, name=name+'_bottom')
+ self.add_line_visual(vertices[6], vertices[7], color, width=0.005, name=name+'_tail')
+
+ def add_frame_visual(self):
+ self.add_line_visual(np.array([0, 0, 0]), np.array([1, 0, 0]), [1, 0, 0], width=0.005, name='frame_x')
+ self.add_line_visual(np.array([0, 0, 0]), np.array([0, 1, 0]), [0, 1, 0], width=0.005, name='frame_y')
+ self.add_line_visual(np.array([0, 0, 0]), np.array([0, 0, 1]), [0, 0, 1], width=0.005, name='frame_z')
+
+ def remove_visual(self, name):
+ self.visual_builder.remove_visual_at(self.visual_objects[name].get_id())
+ self.scene.remove_actor(self.visual_objects[name])
+ del self.visual_objects[name]
+
+ def remove_grasp_visual(self, name):
+ self.remove_visual(name+'_left')
+ self.remove_visual(name+'_right')
+ self.remove_visual(name+'_bottom')
+ self.remove_visual(name+'_tail')
+
+ def remove_frame_visual(self):
+ self.remove_visual('frame_x')
+ self.remove_visual('frame_y')
+ self.remove_visual('frame_z')
+
+ def remove_all_visuals(self):
+ names = list(self.visual_objects.keys())
+ for name in names:
+ self.remove_visual(name)
diff --git a/envs/real_camera.py b/envs/real_camera.py
new file mode 100644
index 0000000..f2a0218
--- /dev/null
+++ b/envs/real_camera.py
@@ -0,0 +1,210 @@
+import pyrealsense2 as rs
+import numpy as np
+import cv2
+
+class CameraL515(object):
+ def __init__(self):
+ self.pipeline = rs.pipeline()
+ self.config = rs.config()
+ self.config.enable_stream(rs.stream.depth, 1024, 768, rs.format.z16, 30)
+ self.config.enable_stream(rs.stream.color, 1280, 720, rs.format.bgr8, 30)
+ self.align_to = rs.stream.color
+ self.align = rs.align(self.align_to)
+
+ self.pipeline_profile = self.pipeline.start(self.config)
+ try:
+ self.device = self.pipeline_profile.get_device()
+ self.mtx = self.getIntrinsics()
+
+ self.hole_filling = rs.hole_filling_filter()
+
+ align_to = rs.stream.color
+ self.align = rs.align(align_to)
+
+ # camera init warm up
+ i = 60
+ while i>0:
+ frames = self.pipeline.wait_for_frames()
+ aligned_frames = self.align.process(frames)
+ depth_frame = aligned_frames.get_depth_frame()
+ color_frame = aligned_frames.get_color_frame()
+ # pdb.set_trace()
+ color_frame.get_profile().as_video_stream_profile().get_intrinsics()
+ if not depth_frame or not color_frame:
+ continue
+ depth_image = np.asanyarray(depth_frame.get_data())
+ color_image = np.asanyarray(color_frame.get_data())
+ i -= 1
+ except:
+ self.__del__()
+ raise
+
+ def getIntrinsics(self):
+ frames = self.pipeline.wait_for_frames()
+ aligned_frames = self.align.process(frames)
+ color_frame = aligned_frames.get_color_frame()
+ intrinsics = color_frame.get_profile().as_video_stream_profile().get_intrinsics()
+ mtx = [intrinsics.width,intrinsics.height,intrinsics.ppx,intrinsics.ppy,intrinsics.fx,intrinsics.fy]
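+ # assemble the 3x3 pinhole intrinsics matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]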
+ camIntrinsics = np.array([[mtx[4],0,mtx[2]],
+ [0,mtx[5],mtx[3]],
+ [0,0,1.]])
+ return camIntrinsics
+
+ def get_data(self, hole_filling=False):
+ while True:
+ frames = self.pipeline.wait_for_frames()
+ aligned_frames = self.align.process(frames)
+ depth_frame = aligned_frames.get_depth_frame()
+ if hole_filling:
+ depth_frame = self.hole_filling.process(depth_frame)
+ color_frame = aligned_frames.get_color_frame()
+ if not depth_frame or not color_frame:
+ continue
+ depth_image = np.asanyarray(depth_frame.get_data())
+ color_image = np.asanyarray(color_frame.get_data())
+ break
+ return color_image, depth_image
+
+ def get_data1(self, hole_filling=False):
+ while True:
+ frames = self.pipeline.wait_for_frames()
+ # aligned_frames = self.align.process(frames)
+ # depth_frame = aligned_frames.get_depth_frame()
+ # if hole_filling:
+ # depth_frame = self.hole_filling.process(depth_frame)
+ # color_frame = aligned_frames.get_color_frame()
+ depth_frame = frames.get_depth_frame()
+ color_frame = frames.get_color_frame()
+ if not depth_frame or not color_frame:
+ continue
+ colorizer = rs.colorizer()
+ depth_image = np.asanyarray(colorizer.colorize(depth_frame).get_data())
+ # depth_image = np.asanyarray(depth_frame.get_data())
+ color_image = np.asanyarray(color_frame.get_data())
+ break
+ return color_image, depth_image
+
+ def inpaint(self, img, missing_value=0):
+ '''
+ Inpaint missing values in a depth image; requires opencv-python==3.4.8.29.
+ :param img: depth image to inpaint
+ :param missing_value: value that marks missing pixels, default 0
+ :return: the inpainted depth image
+ '''
+ # cv2 inpainting doesn't handle the border properly
+ # https://stackoverflow.com/questions/25974033/inpainting-depth-map-still-a-black-image-border
+ img = cv2.copyMakeBorder(img, 1, 1, 1, 1, cv2.BORDER_DEFAULT)
+ mask = (img == missing_value).astype(np.uint8)
+
+ # Scale to keep as float, but has to be in bounds -1:1 to keep opencv happy.
+ scale = np.abs(img).max()
+ img = img.astype(np.float32) / scale # Has to be float32, 64 not supported.
+ img = cv2.inpaint(img, mask, 1, cv2.INPAINT_NS)
+
+ # Back to original size and value range.
+ img = img[1:-1, 1:-1]
+ img = img * scale
+ return img
+
+ def getXYZRGB(self,color, depth, robot_pose,camee_pose,camIntrinsics,inpaint=True,depth_scale=None):
+ '''
+ Back-project an aligned RGB-D frame into a colored point cloud.
+ :param color: (H, W, 3) BGR image
+ :param depth: (H, W) raw depth image
+ :param robot_pose: array 4*4
+ :param camee_pose: array 4*4, composed with robot_pose as robot_pose @ camee_pose
+ :param camIntrinsics: array 3*3
+ :param inpaint: bool, whether to inpaint missing depth values
+ :param depth_scale: float, factor that converts raw depth to meters
+ :return: (H*W, 6) xyzrgb array, xyz in meters and rgb in [0, 1]
+ '''
+ heightIMG, widthIMG, _ = color.shape
+ # pdb.set_trace()
+ # heightIMG = 720
+ # widthIMG = 1280
+ # depthImg = depth / 10000.
+ assert depth_scale is not None
+ depthImg = depth * depth_scale
+ # depthImg = depth
+ if inpaint:
+ depthImg = self.inpaint(depthImg)
+ robot_pose = np.dot(robot_pose, camee_pose)
+
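+ # pinhole back-projection: X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy, Z = depth in meters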
+ [pixX, pixY] = np.meshgrid(np.arange(widthIMG), np.arange(heightIMG))
+ camX = (pixX - camIntrinsics[0][2]) * depthImg / camIntrinsics[0][0]
+ camY = (pixY - camIntrinsics[1][2]) * depthImg / camIntrinsics[1][1]
+ camZ = depthImg
+
+ camPts = [camX.reshape(camX.shape + (1,)), camY.reshape(camY.shape + (1,)), camZ.reshape(camZ.shape + (1,))]
+ camPts = np.concatenate(camPts, 2)
+ camPts = camPts.reshape((camPts.shape[0] * camPts.shape[1], camPts.shape[2])) # shape = (heightIMG*widthIMG, 3)
+ worldPts = np.dot(robot_pose[:3, :3], camPts.transpose()) + robot_pose[:3, 3].reshape(3, 1) # shape = (3, heightIMG*widthIMG)
+ rgb = color.reshape((-1, 3)) / 255.
+ rgb[:, [0, 2]] = rgb[:, [2, 0]]
+ xyzrgb = np.hstack((worldPts.T, rgb))
+ # xyzrgb = self.getleft(xyzrgb)
+ return xyzrgb
+
+ def getleft(self, obj1):
+ index = np.bitwise_and(obj1[:, 0] < 1.2, obj1[:, 0] > 0.2)
+ index = np.bitwise_and(obj1[:, 1] < 0.5, index)
+ index = np.bitwise_and(obj1[:, 1] > -0.5, index)
+ # index = np.bitwise_and(obj1[:, 2] > -0.1, index)
+ index = np.bitwise_and(obj1[:, 2] > 0.24, index)
+ index = np.bitwise_and(obj1[:, 2] < 0.6, index)
+ return obj1[index]
+
+
+ def __del__(self):
+ self.pipeline.stop()
+
+
+def vis_pc(xyzrgb):
+ import open3d as o3d
+ camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ pc1 = o3d.geometry.PointCloud()
+ pc1.points = o3d.utility.Vector3dVector(xyzrgb[:, :3])
+ pc1.colors = o3d.utility.Vector3dVector(xyzrgb[:, 3:])
+ o3d.visualization.draw_geometries([camera_frame, pc1])
+
+
+if __name__ == "__main__":
+ print("initialize camera")
+ cam = CameraL515()
+ i = 0
+ while True:
+ print(f"{i}th")
+ color, depth = cam.get_data(hole_filling=False)
+
+ depth_sensor = cam.pipeline_profile.get_device().first_depth_sensor()
+ depth_scale = depth_sensor.get_depth_scale()
+
+ # xyz in meter, rgb in [0, 1]
+ xyzrgb = cam.getXYZRGB(color, depth, np.identity(4), np.identity(4), cam.getIntrinsics(), inpaint=False, depth_scale=depth_scale)
+ # xyzrgb = xyzrgb[xyzrgb[:, 2] <= 1.5, :]
+ print(np.mean(xyzrgb[:, 2]))
+ vis_pc(xyzrgb)
+
+ cv2.imshow('color', color)
+ while True:
+ if cv2.getWindowProperty('color', cv2.WND_PROP_VISIBLE) <= 0:
+ break
+ cv2.waitKey(1)
+ cv2.destroyAllWindows()
+
+ cmd = input("whether save? (y/n): ")
+ if cmd == 'y':
+ cv2.imwrite(f"rgb_{i}.png", color)
+ np.savez(f"xyzrgb_{i}.npz", point_cloud=xyzrgb[:, :3], rgb=xyzrgb[:, 3:])
+ i += 1
+ elif cmd == 'n':
+ cmd = input("whether quit? (y/n): ")
+ if cmd == 'y':
+ break
+ elif cmd == 'n':
+ pass
+ else:
+ raise ValueError
+ else:
+ raise ValueError
diff --git a/envs/real_robot.py b/envs/real_robot.py
new file mode 100644
index 0000000..d8d5aee
--- /dev/null
+++ b/envs/real_robot.py
@@ -0,0 +1,169 @@
+import time
+import numpy as np
+import transformations as tf
+from frankx import Affine, JointMotion, Robot, Waypoint, WaypointMotion, Gripper, LinearRelativeMotion, LinearMotion, ImpedanceMotion
+
+
+class Panda():
+ def __init__(self,host='172.16.0.2'):
+ self.robot = Robot(host)
+ self.gripper = Gripper(host)
+ self.setGripper(20,0.1)
+ self.max_gripper_width = 0.08
+ self.robot.set_default_behavior()
+ self.robot.recover_from_errors()
+ # Reduce the acceleration and velocity dynamics (relative to the robot's maximum)
+ self.robot.set_dynamic_rel(0.2)
+ # self.robot.set_dynamic_rel(0.05)
+
+ # self.robot.velocity_rel = 0.1
+ # self.robot.acceleration_rel = 0.02
+ # self.robot.jerk_rel = 0.01
+
+ self.joint_tolerance = 0.01
+ # state = self.robot.read_once()
+ # print('\nPose: ', self.robot.current_pose())
+ # print('O_TT_E: ', state.O_T_EE)
+ # print('Joints: ', state.q)
+ # print('Elbow: ', state.elbow)
+ # pdb.set_trace()
+ self.in_impedance_control = False
+
+ def setGripper(self,force=20.0, speed=0.02):
+ self.gripper.gripper_speed = speed # m/s
+ self.gripper.gripper_force = force # N
+
+ def gripper_close(self) -> bool:
+ # can be used to grasp
+ is_grasped = self.gripper.clamp()
+ return is_grasped
+
+ def gripper_open(self) -> None:
+ self.gripper.open()
+
+ def gripper_release(self, width:float) -> None:
+ # can be used to release after grasping
+ self.gripper.release(min(max(width, 0.0), self.max_gripper_width))
+
+ def move_gripper(self, width:float) -> None:
+ self.gripper.move(min(max(width, 0.0), self.max_gripper_width), self.gripper.gripper_speed) # m
+
+ def read_gripper(self) -> float:
+ return self.gripper.width()
+
+ def is_grasping(self) -> bool:
+ return self.gripper.is_grasping()
+
+ def moveJoint(self,joint,moveTarget=True):
+ assert len(joint)==7, "panda DOF is 7"
+ if not self.in_impedance_control:
+ self.robot.move(JointMotion(joint))
+ else:
+ raise NotImplementedError
+ # while moveTarget:
+ # current = self.robot.read_once().q
+ # if all([np.abs(current[j] - joint[j]) < self.joint_tolerance for j in range(len(joint))]):
+ # break
+
+ def readJoint(self):
+ # NOTE: minimal sketch of a joint-reading helper (assumed), mirroring readPose()/readWrench() below
+ if not self.in_impedance_control:
+ joint = np.array(self.robot.read_once().q)
+ else:
+ joint = np.array(self.impedance_motion.get_robotstate().q)
+ return joint
+
+ def homing(self) -> None:
+ # joint = [-0.2918438928353856, -0.970780364569858, 0.10614118311070558, -1.3263296233896118, 0.28714199085241543, 1.4429661556967983, 0.8502855184922615] # microwave: pad + safe (rotate)
+ # joint = [-0.32146529861813555, -0.6174831717455548, 0.08796035485936884, -0.8542264670393085, 0.2846642250021548, 1.2692416279845777, 0.7918693021188179] # refrigerator: storagefurniture
+ # joint = [0.07228589984826875, -0.856545356798933, -0.005984785356738588, -1.446693722022207, -0.0739646362066269, 1.5132004619969288, 0.8178283093992272] # safe: pad + microwave
+ # joint = [-0.2249300478901436, -0.8004290411025673, 0.10279602963647609, -1.2284506426734476, 0.22189371273337696, 1.3787900806797873, 0.7783415511498849] # storagefurniture: microwave
+ # joint = [-0.2249300478901436, -0.8004290411025673, 0.10279602963647609, -1.2284506426734476, 0.22189371273337696, 1.3787900806797873, 0.7783415511498849] # drawer: microwave
+ joint = [-0.23090655681806208, -0.7368697004085705, 0.06469194421473157, -1.5633050220115945, 0.06594510726133981, 1.6454856337730452, 0.7169523042954654] # washingmachine: pad + microwave
+ self.moveJoint(joint)
+
+ def readPose(self):
+ if not self.in_impedance_control:
+ pose = np.array(self.robot.read_once().O_T_EE).reshape(4, 4).T # EE2robot, gripper pose
+ else:
+ pose = np.array(self.impedance_motion.get_robotstate().O_T_EE).reshape(4, 4).T
+ return pose
+
+ def movePose(self, pose):
+ # gripper pose
+ # tf.euler_from_matrix(pose, axes='rzyx')
+ # R.from_euler('ZYX', [-1.560670, -0.745688, 1.922058]).as_matrix(), rpy->matrix
+ if not self.in_impedance_control:
+ tr = pose[:3, 3]
+ rot = tf.euler_from_matrix(pose, axes='rzyx')
+ motion = LinearMotion(Affine(tr[0], tr[1], tr[2], rot[0], rot[1], rot[2]))
+ self.robot.move(motion)
+ else:
+ tr = pose[:3, 3]
+ rot = tf.euler_from_matrix(pose, axes='rzyx')
+ self.impedance_motion.target = Affine(tr[0], tr[1], tr[2], rot[0], rot[1], rot[2])
+
+ def start_impedance_control(self, tr_stiffness=1000.0, rot_stiffness=20.0):
+ print("you need rebuild frankx to support this")
+ self.impedance_motion = ImpedanceMotion(tr_stiffness, rot_stiffness)
+ self.robot_thread = self.robot.move_async(self.impedance_motion)
+ time.sleep(0.5)
+ self.in_impedance_control = True
+
+ def end_impedance_control(self):
+ self.impedance_motion.finish()
+ self.robot_thread.join()
+ self.impedance_motion = None
+ self.robot_thread = None
+ self.in_impedance_control = False
+
+ def readWrench(self):
+ if not self.in_impedance_control:
+ wrench = np.array(self.robot.read_once().O_F_ext_hat_K) # in base
+ else:
+ wrench = np.array(self.impedance_motion.get_robotstate().O_F_ext_hat_K)
+ return wrench
+
+
+if __name__ == "__main__":
+ robot = Panda()
+ # test gripper
+ robot.gripper_close()
+ robot.gripper_open()
+ is_grasped = robot.gripper_close()
+ is_grasped = is_grasped and robot.is_grasping()
+ print("is_grasped:", is_grasped)
+ robot.move_gripper(0.05)
+ gripper_width = robot.read_gripper()
+ print("gripper width:", gripper_width)
+ # test arm
+ robot.homing()
+ joint = robot.readJoint()
+ print("current joint:", joint)
+ EE2robot = robot.readPose()
+ print("current pose:", EE2robot)
+ target_pose = EE2robot.copy()
+ target_pose[:3, 3] += np.array([0.05, 0.05, 0.05])
+ robot.movePose(target_pose)
+ joint = robot.readJoint()
+ print("current joint:", joint)
+ EE2robot = robot.readPose()
+ print("current pose:", EE2robot)
+ robot.start_impedance_control()
+ for i in range(10):
+ current_pose = robot.readPose()
+ print(i, current_pose)
+ target_pose = current_pose.copy()
+ target_pose[:3, 3] += np.array([0., 0.02, 0.])
+ robot.movePose(target_pose)
+ time.sleep(0.3)
+ EE2robot = robot.readPose()
+ print("current pose:", EE2robot)
+ robot.end_impedance_control()
+ robot.homing()
+ # pose = robot.readPose()
+ # np.save("pose_00.npy", pose)
+ # joints = robot.readJoint()
+ # np.save("joints_00.npy", joints)
diff --git a/envs/robot.py b/envs/robot.py
new file mode 100644
index 0000000..48bbe7d
--- /dev/null
+++ b/envs/robot.py
@@ -0,0 +1,354 @@
+"""
+Modified from https://github.com/warshallrho/VAT-Mart/blob/main/code/robots/panda_robot.py
+ Franka Panda Robot Arm
+ supports panda.urdf and panda_gripper.urdf
+"""
+
+from __future__ import division
+import numpy as np
+from PIL import Image
+from sapien.core import Pose
+
+from .env import Env
+
+
+def rot2so3(rotation):
+ assert rotation.shape == (3, 3)
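+ # (near-)identity rotation: return a zero axis with theta = 1, so that pose2exp_coordinate can still
+ # divide by theta and recover a pure-translation twist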
+ if np.isclose(rotation.trace(), 3):
+ return np.zeros(3), 1
+ if np.isclose(rotation.trace(), -1):
+ raise RuntimeError
+ theta = np.arccos((rotation.trace() - 1) / 2)
+ omega = 1 / 2 / np.sin(theta) * np.array(
+ [rotation[2, 1] - rotation[1, 2], rotation[0, 2] - rotation[2, 0], rotation[1, 0] - rotation[0, 1]]).T
+ return omega, theta
+
+def skew(vec):
+ return np.array([[0, -vec[2], vec[1]],
+ [vec[2], 0, -vec[0]],
+ [-vec[1], vec[0], 0]])
+
+def pose2exp_coordinate(pose):
+ """
+ Compute the exponential coordinate corresponding to the given SE(3) matrix
+ Note: unit twist is not a unit vector
+
+ Args:
+ pose: (4, 4) transformation matrix
+
+ Returns:
+ Unit twist: (6, ) vector represent the unit twist
+ Theta: scalar represent the quantity of exponential coordinate
+ """
+ omega, theta = rot2so3(pose[:3, :3])
+ ss = skew(omega)
+ inv_left_jacobian = np.eye(3, dtype=np.float64) / theta - 0.5 * ss + (
+ 1.0 / theta - 0.5 / np.tan(theta / 2)) * ss @ ss
+ v = inv_left_jacobian @ pose[:3, 3]
+ return np.concatenate([omega, v]), theta
+
+def adjoint_matrix(pose):
+ adjoint = np.zeros([6, 6])
+ adjoint[:3, :3] = pose[:3, :3]
+ adjoint[3:6, 3:6] = pose[:3, :3]
+ adjoint[3:6, 0:3] = skew(pose[:3, 3]) @ pose[:3, :3]
+ return adjoint
+
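+# A minimal sketch of how the helpers above are combined by Robot.calculate_twist below:
+#   T_rel = T_ee_current^{-1} @ T_ee_target           relative transform in the end-effector frame
+#   (S, theta) = pose2exp_coordinate(T_rel)           unit body twist and its magnitude
+#   body_twist = S * theta / time_to_target           constant twist that reaches the target in the given time
+#   spatial_twist = adjoint_matrix(T_ee_current) @ body_twist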
+
+class Robot(object):
+ def __init__(self, env:Env, urdf, material, open_gripper=False, scale=1.0):
+ self.env = env
+ self.timestep = env.scene.get_timestep()
+
+ # load robot
+ loader = env.scene.create_urdf_loader()
+ loader.fix_root_link = True
+ loader.scale = scale
+ self.robot = loader.load(urdf, {"material": material})
+ self.robot.name = "robot"
+ self.max_gripper_width = 0.08
+ self.tcp2ee_length = 0.11
+ self.scale = scale
+
+ # hand (EE), two grippers, the rest arm joints (if any)
+ self.end_effector_index, self.end_effector = \
+ [(i, l) for i, l in enumerate(self.robot.get_links()) if l.name == 'panda_hand'][0]
+ self.root2ee = self.end_effector.get_pose().to_transformation_matrix() @ self.robot.get_root_pose().inv().to_transformation_matrix()
+ self.hand_actor_id = self.end_effector.get_id()
+ self.gripper_joints = [joint for joint in self.robot.get_joints() if
+ joint.get_name().startswith("panda_finger_joint")]
+ self.gripper_actor_ids = [joint.get_child_link().get_id() for joint in self.gripper_joints]
+ self.arm_joints = [joint for joint in self.robot.get_joints() if
+ joint.get_dof() > 0 and not joint.get_name().startswith("panda_finger")]
+ self.g2g = np.array([[0., 0., 1.], [0., 1., 0.], [-1., 0., 0.]])
+
+ # set drive joint property
+ for joint in self.arm_joints:
+ joint.set_drive_property(1000, 400)
+ for joint in self.gripper_joints:
+ joint.set_drive_property(200, 60)
+
+ # open/close the gripper at start
+ if open_gripper:
+ joint_angles = []
+ for j in self.robot.get_joints():
+ if j.get_dof() == 1:
+ if j.get_name().startswith("panda_finger_joint"):
+ joint_angles.append(self.max_gripper_width / 2.0 * scale)
+ else:
+ joint_angles.append(0)
+ self.robot.set_qpos(joint_angles)
+
+ def load_gripper(self, urdf, material, open_gripper=False, scale=1.0):
+ raise NotImplementedError
+ self.timestep = self.env.scene.get_timestep()
+
+ # load robot
+ loader = self.env.scene.create_urdf_loader()
+ loader.fix_root_link = True
+ loader.scale = scale
+ self.robot = loader.load(urdf, {"material": material})
+ self.robot.name = "robot"
+ self.max_gripper_width = 0.08
+ self.tcp2ee_length = 0.11
+ self.scale = scale
+
+ # hand (EE), two grippers, the rest arm joints (if any)
+ self.end_effector_index, self.end_effector = \
+ [(i, l) for i, l in enumerate(self.robot.get_links()) if l.name == 'panda_hand'][0]
+ self.root2ee = self.end_effector.get_pose().to_transformation_matrix() @ self.robot.get_root_pose().inv().to_transformation_matrix()
+ self.hand_actor_id = self.end_effector.get_id()
+ self.gripper_joints = [joint for joint in self.robot.get_joints() if
+ joint.get_name().startswith("panda_finger_joint")]
+ self.gripper_actor_ids = [joint.get_child_link().get_id() for joint in self.gripper_joints]
+ self.arm_joints = [joint for joint in self.robot.get_joints() if
+ joint.get_dof() > 0 and not joint.get_name().startswith("panda_finger")]
+
+ # set drive joint property
+ for joint in self.arm_joints:
+ joint.set_drive_property(1000, 400)
+ for joint in self.gripper_joints:
+ joint.set_drive_property(200, 60)
+
+ # open/close the gripper at start
+ if open_gripper:
+ joint_angles = []
+ for j in self.robot.get_joints():
+ if j.get_dof() == 1:
+ if j.get_name().startswith("panda_finger_joint"):
+ joint_angles.append(self.max_gripper_width / 2.0 * scale)
+ else:
+ joint_angles.append(0)
+ self.robot.set_qpos(joint_angles)
+
+ def compute_joint_velocity_from_twist(self, twist: np.ndarray) -> np.ndarray:
+ """
+ This function is a kinematic-level calculation that does not consider dynamics.
+ Pay attention to the frame of the twist: is it a spatial twist or a body twist?
+
+ The Jacobian is provided for you, so there is no need to compute the velocity kinematics.
+ ee_jacobian is the geometric Jacobian accounting only for the robot arm joints, not the gripper.
+ Jacobian in SAPIEN is defined as the derivative of spatial twist with respect to joint velocity
+
+ Args:
+ twist: (6,) vector to represent the twist
+
+ Returns:
+ (7, ) vector for the velocity of arm joints (not include gripper)
+
+ """
+ assert twist.size == 6
+ # The Jacobian defined in SAPIEN uses the twist ordering (v, \omega), which differs from the (\omega, v) ordering used by pose2exp_coordinate above
+ # So we perform the matrix block operation below
+ dense_jacobian = self.robot.compute_spatial_twist_jacobian() # (num_link * 6, dof())
+ ee_jacobian = np.zeros([6, self.robot.dof - 2])
+ ee_jacobian[:3, :] = dense_jacobian[self.end_effector_index * 6 - 3: self.end_effector_index * 6, :self.robot.dof - 2]
+ ee_jacobian[3:6, :] = dense_jacobian[(self.end_effector_index - 1) * 6: self.end_effector_index * 6 - 3, :self.robot.dof - 2]
+
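+ # rcond truncates small singular values in the pseudo-inverse, which damps the commanded joint
+ # velocities near kinematic singularities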
+ inverse_jacobian = np.linalg.pinv(ee_jacobian, rcond=1e-2)
+ return inverse_jacobian @ twist
+
+ def internal_controller(self, qvel: np.ndarray) -> None:
+ """Control the robot dynamically to execute the given twist for one time step
+
+ This method will try to execute the joint velocity using the internal dynamics function in SAPIEN.
+
+ Note that this function only acts for one time step, so you may need to call it multiple times in your code.
+ Also, this controller is not perfect: there will still be some residual motion even after you have finished
+ using it. Thus, after calling it multiple times, wait for some steps using self.wait_n_steps(n) to allow the
+ robot to reach the target position.
+
+ Args:
+ qvel: (7,) vector to represent the joint velocity
+
+ """
+ assert qvel.size == len(self.arm_joints)
+ target_qpos = qvel * self.timestep + self.robot.get_drive_target()[:-2]
+ for i, joint in enumerate(self.arm_joints):
+ joint.set_drive_velocity_target(qvel[i])
+ joint.set_drive_target(target_qpos[i])
+ passive_force = self.robot.compute_passive_force()
+ self.robot.set_qf(passive_force)
+
+ def calculate_twist(self, time_to_target, target_ee_pose):
+ relative_transform = self.end_effector.get_pose().inv().to_transformation_matrix() @ target_ee_pose
+ unit_twist, theta = pose2exp_coordinate(relative_transform)
+ velocity = theta / time_to_target
+ body_twist = unit_twist * velocity
+ current_ee_pose = self.end_effector.get_pose().to_transformation_matrix()
+ return adjoint_matrix(current_ee_pose) @ body_twist
+
+ def set_pose(self, target_ee_pose:np.ndarray, gripper_depth:float):
+ # target_ee_pose: (4, 4) transformation of robot tcp in world frame, as grasp pose
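+ # The TCP (grasp) frame is re-oriented into the panda_hand root frame with self.g2g, then the position is
+ # moved back along the resulting z-axis (the approach direction) by (tcp2ee_length * scale - gripper_depth).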
+ root_pose = np.identity(4)
+ root_pose[:3, :3] = target_ee_pose[:3, :3] @ self.g2g
+ root_pose[:3, 3] = target_ee_pose[:3, 3]
+ root_pose[:3, 3] -= (self.tcp2ee_length * self.scale - gripper_depth) * root_pose[:3, 2]
+ root_pose = np.linalg.inv(self.root2ee) @ root_pose
+ self.robot.set_root_pose(Pose().from_transformation_matrix(root_pose))
+
+ def get_pose(self, gripper_depth:float) -> np.ndarray:
+ target_ee_pose = np.identity(4)
+ # root_pose = self.robot.get_root_pose().to_transformation_matrix()
+ root_pose = self.end_effector.get_pose().to_transformation_matrix()
+ target_ee_pose[:3, 3] = root_pose[:3, 3] + (self.tcp2ee_length * self.scale - gripper_depth) * root_pose[:3, 2]
+ target_ee_pose[:3, :3] = root_pose[:3, :3] @ np.linalg.inv(self.g2g)
+ return target_ee_pose
+
+ def move_to_target_pose(self, target_ee_pose: np.ndarray, gripper_depth:float, num_steps: int, visu=None, vis_gif=False, vis_gif_interval=200, cam=None) -> None:
+ """
+ Move the robot hand dynamically to a given target pose
+ Args:
+ target_ee_pose: (4, 4) transformation of robot tcp in world frame, as grasp pose
+ num_steps: how many steps to take to reach the target pose,
+ each step corresponds to self.scene.get_timestep() seconds
+ in the physical simulation
+ """
+ if visu:
+ waypoints = []
+ if vis_gif:
+ imgs = []
+
+ executed_time = num_steps * self.timestep
+
+ target_ee_root_pose = np.identity(4)
+ target_ee_root_pose[:3, :3] = target_ee_pose[:3, :3] @ self.g2g
+ target_ee_root_pose[:3, 3] = target_ee_pose[:3, 3]
+ target_ee_root_pose[:3, 3] -= (self.tcp2ee_length * self.scale - gripper_depth) * target_ee_root_pose[:3, 2]
+
+ spatial_twist = self.calculate_twist(executed_time, target_ee_root_pose)
+ for i in range(num_steps):
+ if i % 100 == 0:
+ spatial_twist = self.calculate_twist((num_steps - i) * self.timestep, target_ee_root_pose)
+ qvel = self.compute_joint_velocity_from_twist(spatial_twist)
+ self.internal_controller(qvel)
+ self.env.step()
+ self.env.render()
+ if visu and i % 200 == 0:
+ waypoints.append(self.robot.get_qpos().tolist())
+ if vis_gif and ((i + 1) % vis_gif_interval == 0):
+ rgb_pose, _ = cam.get_observation()
+ fimg = (rgb_pose*255).astype(np.uint8)
+ fimg = Image.fromarray(fimg)
+ imgs.append(fimg)
+ if vis_gif and (i == 0):
+ rgb_pose, _ = cam.get_observation()
+ fimg = (rgb_pose*255).astype(np.uint8)
+ fimg = Image.fromarray(fimg)
+ for idx in range(5):
+ imgs.append(fimg)
+
+ if visu and not vis_gif:
+ return waypoints
+ if vis_gif and not visu:
+ return imgs
+ if visu and vis_gif:
+ return imgs, waypoints
+
+ def move_to_target_qvel(self, qvel) -> None:
+
+ """
+ Move the robot hand dynamically to a given target pose
+ Args:
+ target_ee_pose: (4, 4) transformation of robot hand in robot base frame (ee2base)
+ num_steps: how much steps to reach to target pose,
+ each step correspond to self.scene.get_timestep() seconds
+ in physical simulation
+ """
+ assert qvel.size == len(self.arm_joints)
+ for idx_step in range(100):
+ target_qpos = qvel * self.timestep + self.robot.get_drive_target()[:-2]
+ for i, joint in enumerate(self.arm_joints):
+ joint.set_drive_velocity_target(qvel[i])
+ joint.set_drive_target(target_qpos[i])
+ passive_force = self.robot.compute_passive_force()
+ self.robot.set_qf(passive_force)
+ self.env.step()
+ self.env.render()
+ return
+
+
+ def close_gripper(self):
+ for joint in self.gripper_joints:
+ joint.set_drive_target(0.0)
+
+ def open_gripper(self):
+ for joint in self.gripper_joints:
+ joint.set_drive_target(self.max_gripper_width / 2.0)
+
+ def set_gripper(self, width:float):
+ joint_angles = []
+ for j in self.robot.get_joints():
+ if j.get_dof() == 1:
+ if j.get_name().startswith("panda_finger_joint"):
+ joint_angles.append(max(min(width, self.max_gripper_width) / 2.0, 0.0))
+ else:
+ joint_angles.append(0)
+ self.robot.set_qpos(joint_angles)
+ for joint in self.gripper_joints:
+ joint.set_drive_target(max(min(width, self.max_gripper_width) / 2.0, 0.0))
+
+ def get_gripper(self) -> float:
+ joint_angles = self.robot.get_qpos()
+ width = 0.0
+ j_idx = 0
+ for j in self.robot.get_joints():
+ if j.get_dof() == 1:
+ if j.get_name().startswith("panda_finger_joint"):
+ width += joint_angles[j_idx]
+ else:
+ pass
+ j_idx += 1
+ return width
+
+ def clear_velocity_command(self):
+ for joint in self.arm_joints:
+ joint.set_drive_velocity_target(0)
+
+ def wait_n_steps(self, n: int, visu=None, vis_gif=False, vis_gif_interval=200, cam=None):
+ imgs = []
+ if visu:
+ waypoints = []
+ self.clear_velocity_command()
+ for i in range(n):
+ passive_force = self.robot.compute_passive_force()
+ self.robot.set_qf(passive_force)
+ self.env.step()
+ self.env.render()
+ if visu and i % 200 == 0:
+ waypoints.append(self.robot.get_qpos().tolist())
+ if vis_gif and ((i + 1) % vis_gif_interval == 0):
+ rgb_pose, _ = cam.get_observation()
+ fimg = (rgb_pose*255).astype(np.uint8)
+ fimg = Image.fromarray(fimg)
+ imgs.append(fimg)
+ #
+ self.robot.set_qf([0] * self.robot.dof)
+ if visu and vis_gif:
+ return imgs, waypoints
+
+ if visu:
+ return waypoints
+ if vis_gif:
+ return imgs
+
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..50e8b05
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,691 @@
+import configargparse
+import json
+from omegaconf import OmegaConf
+import random
+import os
+import logging
+import numpy as np
+import tqdm
+import time
+import datetime
+import torch
+import imageio
+from PIL import Image
+import transformations as tf
+import open3d as o3d
+
+from envs.env import Env
+from envs.camera import Camera
+from envs.robot import Robot
+from utilities.env_utils import setup_seed
+from utilities.data_utils import transform_pc, transform_dir, read_joints_from_urdf_file, pc_noise
+from utilities.metrics_utils import calc_pose_error, invaffordance_metrics, invaffordances2affordance, calc_translation_error, calc_direction_error
+from utilities.constants import seed, max_grasp_width
+
+
+def config_parse() -> configargparse.Namespace:
+ parser = configargparse.ArgumentParser()
+
+ # environment config
+ parser.add_argument('--camera_config_path', type=str, default='./configs/data/camera_config.json', help='the path to camera config')
+ parser.add_argument('--num_config_per_object', type=int, default=200, help='the number of configs per object')
+ # data config
+ parser.add_argument('--scale', type=float, default=0, help='the scale of the object')
+ parser.add_argument('--data_path', type=str, default='/data2/junbo/where2act_modified_sapien_dataset/7167/mobility_vhacd.urdf', help='the path to the data')
+ parser.add_argument('--cat', type=str, default='Microwave', help='the category of the object')
+ parser.add_argument('--object_config_path', type=str, default='./configs/data/object_config.json', help='the path to object config')
+ # robot config
+ parser.add_argument('--robot_urdf_path', type=str, default='/data2/junbo/franka_panda/panda_gripper.urdf', help='the path to robot urdf')
+ parser.add_argument('--robot_scale', type=float, default=1, help='the scale of the robot')
+ # gt config
+ parser.add_argument('--gt_path', type=str, default='/data2/junbo/where2act_modified_sapien_dataset/7167/joint_abs_pose.json', help='the path to gt')
+ # model config
+ parser.add_argument('--roartnet', action='store_true', help='whether call roartnet')
+ parser.add_argument('--roartnet_config_path', type=str, default='./configs/eval_config.yaml', help='the path to roartnet config')
+ # grasp config
+ parser.add_argument('--graspnet', action='store_true', help='whether call graspnet')
+ parser.add_argument('--grasp', action='store_true', help='whether grasp')
+ parser.add_argument('--gsnet_weight_path', type=str, default='./weights/checkpoint_detection.tar', help='the path to graspnet weight')
+ parser.add_argument('--show_all_grasps', action='store_true', help='whether to show all grasps detected by graspnet; note that the visuals will slow down execution')
+ # task config
+ parser.add_argument('--selected_part', type=int, default=0, help='the selected part of the object')
+ # parser.add_argument('--manipulation', type=float, default=-20, help='the manipulation task, positive as push, negative as pull, revolute in degree, prismatic in cm')
+ parser.add_argument('--task', type=str, default='pull', choices=['pull', 'push', 'none'], help='the task')
+ parser.add_argument('--task_low', type=float, default=0.1, help='low bound of task')
+ parser.add_argument('--task_high', type=float, default=0.7, help='high bound of task')
+ parser.add_argument('--success_threshold', type=float, default=0.15, help='success threshold for ratio of manipulated movement')
+ # others
+ parser.add_argument('--gui', action='store_true', help='whether show gui')
+ parser.add_argument('--video', action='store_true', help='whether save video')
+ parser.add_argument('--output_path', type=str, default='./outputs/manipulation', help='the path to output')
+ parser.add_argument('--abbr', type=str, default='7167', help='the abbr of the object')
+ parser.add_argument('--seed', type=int, default=seed, help='the random seed')
+
+ args = parser.parse_args()
+ return args
+
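+# Example invocation (a sketch; the paths are just the argparse defaults above, adjust them to your setup):
+#   python eval.py --roartnet --graspnet --grasp --cat Microwave --abbr 7167 --task pull \
+#       --data_path /data2/junbo/where2act_modified_sapien_dataset/7167/mobility_vhacd.urdf \
+#       --gt_path /data2/junbo/where2act_modified_sapien_dataset/7167/joint_abs_pose.json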
+
+if __name__ == "__main__":
+ args = config_parse()
+ setup_seed(args.seed)
+ # TODO: hardcoded here to add noise to the point cloud
+ output_name = 'noise_' + args.cat + '_' + args.abbr + '_' + args.task
+ if args.roartnet:
+ output_name += '_roartnet'
+ output_name += '_' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
+ output_path = os.path.join(args.output_path, output_name)
+ os.makedirs(output_path, exist_ok=True)
+ logger = logging.getLogger("manipulation")
+ logger.setLevel(level=logging.DEBUG)
+ handler = logging.FileHandler(filename=os.path.join(output_path, 'log.txt'))
+ logger.addHandler(handler)
+
+ if args.roartnet:
+ # load roartnet
+ start_time = time.time()
+ from models.roartnet import create_shot_encoder, create_encoder
+ from inference import inference_fn as roartnet_inference_fn
+ roartnet_cfg = OmegaConf.load(args.roartnet_config_path)
+ trained_path = roartnet_cfg.trained.path[args.cat]
+ trained_cfg = OmegaConf.load(f"{trained_path}/.hydra/config.yaml")
+ roartnet_cfg = OmegaConf.merge(trained_cfg, roartnet_cfg)
+ joint_num = roartnet_cfg.dataset.joint_num
+ resolution = roartnet_cfg.dataset.resolution
+ receptive_field = roartnet_cfg.dataset.receptive_field
+ has_rgb = roartnet_cfg.dataset.rgb
+ denoise = roartnet_cfg.dataset.denoise
+ normalize = roartnet_cfg.dataset.normalize
+ sample_points_num = roartnet_cfg.dataset.sample_points_num
+ sample_tuples_num = roartnet_cfg.algorithm.sampling.sample_tuples_num
+ tuple_more_num = roartnet_cfg.algorithm.sampling.tuple_more_num
+ shot_hidden_dims = roartnet_cfg.algorithm.shot_encoder.hidden_dims
+ shot_feature_dim = roartnet_cfg.algorithm.shot_encoder.feature_dim
+ shot_bn = roartnet_cfg.algorithm.shot_encoder.bn
+ shot_ln = roartnet_cfg.algorithm.shot_encoder.ln
+ shot_dropout = roartnet_cfg.algorithm.shot_encoder.dropout
+ shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim,
+ shot_bn, shot_ln, shot_dropout)
+ shot_encoder.load_state_dict(torch.load(f'{trained_path}/weights/shot_encoder_latest.pth', map_location=torch.device('cuda')))
+ shot_encoder = shot_encoder.cuda()
+ shot_encoder.eval()
+ overall_hidden_dims = roartnet_cfg.algorithm.encoder.hidden_dims
+ rot_bin_num = roartnet_cfg.algorithm.voting.rot_bin_num
+ overall_bn = roartnet_cfg.algorithm.encoder.bn
+ overall_ln = roartnet_cfg.algorithm.encoder.ln
+ overall_dropout = roartnet_cfg.algorithm.encoder.dropout
+ encoder = create_encoder(tuple_more_num, shot_feature_dim, has_rgb, overall_hidden_dims, rot_bin_num, joint_num,
+ overall_bn, overall_ln, overall_dropout)
+ encoder.load_state_dict(torch.load(f'{trained_path}/weights/encoder_latest.pth', map_location=torch.device('cuda')))
+ encoder = encoder.cuda()
+ encoder.eval()
+ voting_num = roartnet_cfg.algorithm.voting.voting_num
+ angle_tol = roartnet_cfg.algorithm.voting.angle_tol
+ translation2pc = roartnet_cfg.algorithm.voting.translation2pc
+ multi_candidate = roartnet_cfg.algorithm.voting.multi_candidate
+ candidate_threshold = roartnet_cfg.algorithm.voting.candidate_threshold
+ rotation_multi_neighbor = roartnet_cfg.algorithm.voting.rotation_multi_neighbor
+ neighbor_threshold = roartnet_cfg.algorithm.voting.neighbor_threshold
+ rotation_cluster = roartnet_cfg.algorithm.voting.rotation_cluster
+ bmm_size = roartnet_cfg.algorithm.voting.bmm_size
+ end_time = time.time()
+ logger.info(f"===> loaded roartnet {end_time - start_time}")
+
+ if args.graspnet:
+ # load graspnet
+ start_time = time.time()
+ from munch import DefaultMunch
+ from gsnet import AnyGrasp
+ grasp_detector_cfg = {
+ 'checkpoint_path': args.gsnet_weight_path,
+ 'max_gripper_width': max_grasp_width * args.robot_scale,
+ 'gripper_height': 0.03,
+ 'top_down_grasp': False,
+ 'add_vdistance': True,
+ 'debug': True
+ }
+ grasp_detector_cfg = DefaultMunch.fromDict(grasp_detector_cfg)
+ grasp_detector = AnyGrasp(grasp_detector_cfg)
+ grasp_detector.load_net()
+ end_time = time.time()
+ logger.info(f"===> loaded graspnet {end_time - start_time}")
+
+ # initialize environment
+ start_time = time.time()
+ env = Env(show_gui=args.gui)
+ if args.camera_config_path == 'none':
+ camera_config = None
+ cam = Camera(env, random_position=True, restrict_dir=True)
+ else:
+ with open(args.camera_config_path, "r") as fp:
+ camera_config = json.load(fp)
+ camera_near = camera_config['intrinsics']['near']
+ camera_far = camera_config['intrinsics']['far']
+ camera_width = camera_config['intrinsics']['width']
+ camera_height = camera_config['intrinsics']['height']
+ camera_fovx = np.random.uniform(camera_config['intrinsics']['fovx'][0], camera_config['intrinsics']['fovx'][1])
+ camera_fovy = np.random.uniform(camera_config['intrinsics']['fovy'][0], camera_config['intrinsics']['fovy'][1])
+ camera_dist = np.random.uniform(camera_config['extrinsics']['dist'][0], camera_config['extrinsics']['dist'][1])
+ camera_phi = np.random.uniform(camera_config['extrinsics']['phi'][0], camera_config['extrinsics']['phi'][1])
+ camera_theta = np.random.uniform(camera_config['extrinsics']['theta'][0], camera_config['extrinsics']['theta'][1])
+ cam = Camera(env, near=camera_near, far=camera_far, image_size=[camera_width, camera_height], fov=[camera_fovx, camera_fovy],
+ dist=camera_dist, phi=camera_phi / 180 * np.pi, theta=camera_theta / 180 * np.pi)
+ if args.gui:
+ env.set_controller_camera_pose(cam.pos[0], cam.pos[1], cam.pos[2], np.pi + cam.theta, -cam.phi)
+ object_material = env.get_material(4, 4, 0.01)
+ if args.object_config_path == 'none':
+ object_config = {
+ "scale_min": 1,
+ "scale_max": 1,
+ }
+ else:
+ with open(args.object_config_path, "r") as fp:
+ object_config = json.load(fp)
+ if args.gt_path != "none":
+ with open(args.gt_path, "r") as fp:
+ gt_config = json.load(fp)
+ max_object_size = 0
+ for joint_name in gt_config.keys():
+ if joint_name == 'aabb_min' or joint_name == 'aabb_max':
+ continue
+ max_object_size = max(max_object_size, np.max(np.array(gt_config[joint_name]["bbox_max"]) - np.array(gt_config[joint_name]["bbox_min"])))
+ if args.cat in object_config:
+ object_scale = object_config[args.cat]['size'] / max_object_size
+ joint_num = object_config[args.cat]['joint_num']
+ assert joint_num == len(gt_config.keys()) - 2
+ else:
+ object_scale = 1
+ joint_num = len(gt_config.keys()) - 2
+ if args.scale != 0:
+ object_scale = args.scale
+ else:
+ object_scale = args.scale
+ end_time = time.time()
+ logger.info(f"===> initialized environment {end_time - start_time}")
+
+ success_configs, fail_configs = [], []
+ for config_id in tqdm.trange(args.num_config_per_object):
+ # load object
+ this_config_scale = np.random.uniform(object_config["scale_min"], object_config["scale_max"]) * object_scale
+ video = []
+ start_time = time.time()
+ still = False
+ try_times = 0
+ while not still and try_times < 5:
+ goal_qpos, joint_abs_angles = env.load_object(args.data_path, object_material, state='random-middle-middle', target_part_id=-1, target_part_idx=args.selected_part, scale=this_config_scale)
+ env.render()
+
+ # check still and reach goal qpos
+ start_time = time.time()
+ still_timesteps = 0
+ wait_timesteps = 0
+ cur_qpos = env.get_object_qpos()
+ # while still_timesteps < 500 and wait_timesteps < 3000:
+ while still_timesteps < 500 and wait_timesteps < 5000:
+ env.step()
+ env.render()
+ cur_new_qpos = env.get_object_qpos()
+ invalid_contact = False
+ for c in env.scene.get_contacts():
+ for p in c.points:
+ if abs(p.impulse @ p.impulse) > 1e-4:
+ invalid_contact = True
+ break
+ if invalid_contact:
+ break
+ # if np.max(np.abs(cur_new_qpos - cur_qpos)) < 1e-6 and (not invalid_contact):
+ if np.max(np.abs(cur_new_qpos - cur_qpos)) < 1e-6 and np.max(np.abs(cur_new_qpos - goal_qpos)) < 0.02 and (not invalid_contact):
+ still_timesteps += 1
+ else:
+ still_timesteps = 0
+ cur_qpos = cur_new_qpos
+ wait_timesteps += 1
+ still = still_timesteps >= 500
+ if not still:
+ env.scene.remove_articulation(env.object)
+ try_times += 1
+ if not still:
+ logger.info(f"{config_id} failed to load object")
+ continue
+ end_time = time.time()
+ logger.info(f"===> {config_id} loaded object {end_time - start_time} {this_config_scale} {try_times}")
+
+ # set camera
+ if camera_config is None:
+ cam.change_pose(random_position=True, restrict_dir=True)
+ else:
+ camera_fovx = np.random.uniform(camera_config['intrinsics']['fovx'][0], camera_config['intrinsics']['fovx'][1])
+ camera_fovy = np.random.uniform(camera_config['intrinsics']['fovy'][0], camera_config['intrinsics']['fovy'][1])
+ cam.change_fov([camera_fovx, camera_fovy])
+ camera_dist = np.random.uniform(camera_config['extrinsics']['dist'][0], camera_config['extrinsics']['dist'][1])
+ camera_phi = np.random.uniform(camera_config['extrinsics']['phi'][0], camera_config['extrinsics']['phi'][1])
+ camera_theta = np.random.uniform(camera_config['extrinsics']['theta'][0], camera_config['extrinsics']['theta'][1])
+ cam.change_pose(dist=camera_dist, phi=camera_phi / 180 * np.pi, theta=camera_theta / 180 * np.pi)
+ if args.gui:
+ env.set_controller_camera_pose(cam.pos[0], cam.pos[1], cam.pos[2], np.pi + cam.theta, -cam.phi)
+ env.step()
+ env.render()
+ if args.video:
+ frame, _ = cam.get_observation()
+ frame = (frame * 255).astype(np.uint8)
+ frame = Image.fromarray(frame)
+ video.append(frame)
+
+ # get observation
+ rgb, depth = cam.get_observation()
+ cam_XYZA_id1, cam_XYZA_id2, cam_XYZA_pts = cam.compute_camera_XYZA(depth)
+ mask = cam.camera.get_segmentation() # used only for filtering out bad cases
+ object_mask = mask[cam_XYZA_id1, cam_XYZA_id2]
+ extrinsic = cam.get_metadata()['mat44']
+ R = extrinsic[:3, :3]
+ T = extrinsic[:3, 3]
+ pcd_cam = cam_XYZA_pts.copy()
+ pcd_world = (R @ cam_XYZA_pts.T).T + T
+ pcd_color = rgb[cam_XYZA_id1, cam_XYZA_id2]
+ c2c = np.array([[0, -1, 0, 0], [0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 0, 1]])
+ pcd_grasp = transform_pc(pcd_cam, c2c)
+ # TODO: hardcoded here to add noise to the point cloud
+ pcd_cam = pc_noise(pcd_cam, 0.2, 0.01, 0.002, 0.5)
+
+ # obtain gt
+ start_time = time.time()
+ # env.add_frame_visual()
+ # env.render()
+ if args.gt_path != "none":
+ # set task
+ movable_link_ids = env.movable_link_ids
+ object_joint_types = env.movable_link_joint_types
+ movable_link_joint_names = env.movable_link_joint_names
+ all_link_ids = env.all_link_ids
+ all_link_names = env.all_link_names
+ joint_real_states = env.get_object_qpos()
+ joint_lower_states = env.joint_angles_lower
+ joint_upper_states = env.joint_angles_upper
+ env.set_target_object_part_actor_id(movable_link_ids[args.selected_part])
+ if object_joint_types[args.selected_part] == 0:
+ target_joint_real_state = joint_real_states[args.selected_part] / np.pi * 180
+ target_joint_lower_state = joint_lower_states[args.selected_part] / np.pi * 180
+ target_joint_upper_state = joint_upper_states[args.selected_part] / np.pi * 180
+ elif object_joint_types[args.selected_part] == 1:
+ target_joint_real_state = joint_real_states[args.selected_part] * 100
+ target_joint_lower_state = joint_lower_states[args.selected_part] * 100
+ target_joint_upper_state = joint_upper_states[args.selected_part] * 100
+ else:
+ raise ValueError(f"invalid joint type {object_joint_types[args.selected_part]}")
+ if args.task == 'push':
+ task_high = min(args.task_high * (target_joint_upper_state - target_joint_lower_state), target_joint_real_state - target_joint_lower_state - 15)
+ elif args.task == 'pull':
+ task_high = min(args.task_high * (target_joint_upper_state - target_joint_lower_state), target_joint_upper_state - target_joint_real_state - 5)
+ elif args.task == 'none':
+ task_high = target_joint_upper_state - target_joint_lower_state
+ else:
+ raise ValueError(f"invalid task {args.task}")
+ task_low = args.task_low * (target_joint_upper_state - target_joint_lower_state)
+ if task_high <= 0 or task_low >= task_high:
+ logger.info(f"{config_id} cannot set task {target_joint_upper_state - target_joint_lower_state} {target_joint_real_state - target_joint_lower_state}")
+ env.scene.remove_articulation(env.object)
+ continue
+ if args.task == 'push':
+ task_state = random.random() * (task_high - task_low) + task_low
+ elif args.task == 'pull':
+ task_state = random.random() * (task_high - task_low) + task_low
+ task_state *= -1
+ elif args.task == 'none':
+ task_state = 0
+ else:
+ raise ValueError(f"invalid task {args.task}")
+ assert len(movable_link_ids) == len(object_joint_types) and len(object_joint_types) == len(movable_link_joint_names)
+ assert len(movable_link_ids) == len(joint_real_states) and len(joint_real_states) == len(joint_lower_states) and len(joint_lower_states) == len(joint_upper_states)
+ logger.info(f"{config_id} task {task_state}")
+
+ joint_bases, joint_directions, joint_types, joint_res, joint_states, affordable_positions = [], [], [], [], [], []
+ for idx, link_id in enumerate(movable_link_ids):
+ joint_pose_meta = gt_config[movable_link_joint_names[idx]]
+ joint_base = np.asarray(joint_pose_meta['base_position'], order='F') * this_config_scale
+ joint_bases.append(joint_base)
+ # env.add_point_visual(joint_base, color=[0, 1, 0], name='joint_base_{}'.format(idx))
+ joint_direction = np.asarray(joint_pose_meta['direction'], order='F')
+ joint_directions.append(joint_direction)
+ # env.add_line_visual(joint_base, joint_base + this_config_scale * joint_direction, color=[0, 1, 0], name='joint_direction_{}'.format(idx))
+ joint_type = joint_pose_meta['joint_type']
+ assert joint_type == object_joint_types[idx]
+ joint_types.append(joint_type)
+ if joint_type == 0:
+ joint_re = joint_pose_meta["joint_re"]
+ elif joint_type == 1:
+ joint_re = 0
+ else:
+ raise ValueError(f"invalid joint type {joint_pose_meta['joint_type']}")
+ joint_res.append(joint_re)
+ joint_state = joint_real_states[idx] - joint_lower_states[idx]
+ joint_states.append(joint_state)
+ affordable_position = np.asarray(joint_pose_meta['affordable_position'], order='F') * this_config_scale
+ if joint_type == 0:
+ transformation_matrix = tf.rotation_matrix(angle=joint_state * joint_re, direction=joint_direction, point=joint_base)
+ elif joint_type == 1:
+ transformation_matrix = tf.translation_matrix(joint_state * joint_direction)
+ else:
+ raise ValueError(f"invalid joint type {joint_pose_meta['joint_type']}")
+ affordable_position = transform_pc(affordable_position[None, :], transformation_matrix)[0]
+ affordable_positions.append(affordable_position)
+ # env.add_point_visual(affordable_position, color=[0, 0, 1], name='affordable_position_{}'.format(idx))
+ if joint_type == 0:
+ logger.info(f"added joint {movable_link_joint_names[idx]} revolute {'pull_counterclockwise' if joint_re == 1 else 'pull_clockwise'} {joint_state / np.pi * 180} {(joint_upper_states[idx] - joint_lower_states[idx]) / np.pi * 180}")
+ elif joint_type == 1:
+ logger.info(f"added joint {movable_link_joint_names[idx]} prismatic {joint_state * 100} {(joint_upper_states[idx] - joint_lower_states[idx]) * 100}")
+ else:
+ raise ValueError(f"invalid joint type {joint_pose_meta['joint_type']}")
+ selected_joint_base = joint_bases[args.selected_part]
+ selected_joint_direction = joint_directions[args.selected_part]
+ selected_joint_type = joint_types[args.selected_part]
+ selected_joint_re = joint_res[args.selected_part]
+ selected_joint_state = joint_states[args.selected_part]
+ selected_affordable_position = affordable_positions[args.selected_part]
+
+ # merge affiliated parts into parents
+ joints_dict = read_joints_from_urdf_file(args.data_path)
+ link_graph = {}
+ for joint_name in joints_dict:
+ link_graph[joints_dict[joint_name]['child']] = joints_dict[joint_name]['parent']
+ mask_ins = np.unique(object_mask)
+ fixed_id = -1
+ for mask_in in mask_ins:
+ if mask_in not in movable_link_ids:
+ link_name = all_link_names[all_link_ids.index(mask_in)]
+ parent_link = link_graph[link_name]
+ parent_link_id = all_link_ids[all_link_names.index(parent_link)]
+ if parent_link_id in mask_ins:
+ if parent_link_id in movable_link_ids:
+ object_mask[object_mask == mask_in] = parent_link_id
+ else:
+                        # TODO: possibly wrong when several distinct fixed parent links exist; they all collapse into the first one found
+ if fixed_id == -1:
+ fixed_id = parent_link_id
+ object_mask[object_mask == mask_in] = fixed_id
+ else:
+                    # TODO: possibly wrong: the parent link is not in the mask itself, yet its id is reused as the fixed id
+ if fixed_id == -1:
+ fixed_id = parent_link_id
+ object_mask[object_mask == mask_in] = fixed_id
+ mask_ins = np.unique(object_mask)
+ fixed_num = 0
+ for mask_in in mask_ins:
+ if mask_in not in movable_link_ids:
+ fixed_num += 1
+ assert fixed_num <= 1
+ assert mask_ins.shape[0] <= len(movable_link_ids) + 1
+
+ # rearrange mask
+ instance_mask = np.zeros_like(object_mask)
+ real_id = 1
+ selected_joint_idxs = []
+ for mask_id, movable_id in zip(movable_link_ids, range(len(movable_link_ids))):
+ if mask_id in mask_ins:
+ instance_mask[object_mask == mask_id] = real_id
+ real_id += 1
+ selected_joint_idxs.append(movable_id)
+
+ # check regular
+ if len(selected_joint_idxs) != joint_num or np.min(instance_mask) != 0 or np.max(instance_mask) != joint_num:
+ logger.info(f"{config_id} irregular")
+ env.scene.remove_articulation(env.object)
+ continue
+
+ # check seen
+ if (instance_mask == (args.selected_part + 1)).sum() < 0.15 * instance_mask.shape[0]:
+ logger.info(f"{config_id} unseen")
+ env.scene.remove_articulation(env.object)
+ continue
+
+ # check suitable
+ affordable_dist = np.linalg.norm(pcd_world - selected_affordable_position, axis=-1).min()
+ if affordable_dist > 0.03:
+ logger.info(f"{config_id} unsuitable")
+ env.scene.remove_articulation(env.object)
+ continue
+
+ for _ in range(100):
+ env.step()
+ env.render()
+ if args.video:
+ frame, _ = cam.get_observation()
+ frame = (frame * 255).astype(np.uint8)
+ frame = Image.fromarray(frame)
+ video.append(frame)
+ end_time = time.time()
+ logger.info(f"===> added gt {end_time - start_time}")
+
+ # prediction
+ if args.roartnet:
+ start_time = time.time()
+ pred_joint_bases, pred_joint_directions, pred_affordable_positions = roartnet_inference_fn(pcd_cam, pcd_color if has_rgb else None, shot_encoder, encoder,
+ denoise, normalize, resolution, receptive_field, sample_points_num, sample_tuples_num, tuple_more_num,
+ voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold, bmm_size, joint_num, device=0)
+ pred_selected_joint_base = pred_joint_bases[args.selected_part]
+ pred_selected_joint_direction = pred_joint_directions[args.selected_part]
+ pred_selected_affordable_position = pred_affordable_positions[args.selected_part]
+ pred_selected_joint_base = transform_pc(pred_selected_joint_base[None, :], extrinsic)[0]
+ pred_selected_joint_direction = transform_dir(pred_selected_joint_direction[None, :], extrinsic)[0]
+ pred_selected_affordable_position = transform_pc(pred_selected_affordable_position[None, :], extrinsic)[0]
+ joint_translation_errors = calc_translation_error(pred_selected_joint_base, selected_joint_base, pred_selected_joint_direction, selected_joint_direction)
+ joint_direction_error = calc_direction_error(pred_selected_joint_direction, selected_joint_direction)
+ affordance_error, _, _, _, _ = calc_translation_error(pred_selected_affordable_position, selected_affordable_position, None, None)
+ selected_joint_base = pred_selected_joint_base
+ selected_joint_direction = pred_selected_joint_direction
+ selected_affordable_position = pred_selected_affordable_position
+ end_time = time.time()
+ logger.info(f"===> {config_id} roartnet predicted {end_time - start_time} {joint_translation_errors} {joint_direction_error} {affordance_error}")
+
+ # obtain grasp
+ if args.graspnet:
+ start_time = time.time()
+ # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=None, lims=[
+ # np.floor(np.min(pcd_grasp[:, 0])) - 0.1, np.ceil(np.max(pcd_grasp[:, 0])) + 0.1,
+ # np.floor(np.min(pcd_grasp[:, 1])) - 0.1, np.ceil(np.max(pcd_grasp[:, 1])) + 0.1,
+ # np.floor(np.min(pcd_grasp[:, 2])) - 0.1, np.ceil(np.max(pcd_grasp[:, 2])) + 0.1])
+ # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=[-float('inf'), float('inf'), -float('inf'), float('inf'), -float('inf'), float('inf')])
+ # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, apply_object_mask=True, dense_grasp=False, collision_detection=True)
+ try:
+ gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='fast')
+ except:
+ gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='slow')
+ if gg_grasp is None:
+ gg_grasp = []
+ else:
+ if len(gg_grasp) != 2:
+ gg_grasp = []
+ else:
+ gg_grasp, pcd_o3d = gg_grasp
+ gg_grasp = gg_grasp.nms().sort_by_score()
+ # if args.gui:
+ # grippers_o3d = gg_grasp.to_open3d_geometry_list()
+ # frame_o3d = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
+ # o3d.visualization.draw_geometries([*grippers_o3d, pcd_o3d, frame_o3d])
+ end_time = time.time()
+ logger.info(f"===> {config_id} obtained grasp {end_time - start_time} {len(gg_grasp)}")
+ if len(gg_grasp) == 0:
+ logger.info(f"{config_id} no grasp detected")
+ env.scene.remove_articulation(env.object)
+ continue
+
+ # add visuals
+ start_time = time.time()
+ grasp_scores, grasp_widths, grasp_depths, grasp_translations, grasp_rotations, grasp_invaffordances = [], [], [], [], [], []
+ for g_idx, g_grasp in enumerate(gg_grasp):
+ grasp_score = g_grasp.score
+ grasp_scores.append(grasp_score)
+ grasp_width = g_grasp.width
+ grasp_widths.append(grasp_width)
+ grasp_depth = g_grasp.depth
+ grasp_depths.append(grasp_depth)
+ grasp_translation = g_grasp.translation
+ grasp_rotation = g_grasp.rotation_matrix
+ grasp_transformation = np.identity(4)
+ grasp_transformation[:3, :3] = grasp_rotation
+ grasp_transformation[:3, 3] = grasp_translation
+ grasp_transformation = extrinsic @ np.linalg.inv(c2c) @ grasp_transformation
+ grasp_translation = grasp_transformation[:3, 3]
+ grasp_translations.append(grasp_translation)
+ grasp_rotation = grasp_transformation[:3, :3]
+ grasp_rotations.append(grasp_rotation)
+ grasp_invaffordance = invaffordance_metrics(grasp_translation, grasp_rotation, grasp_score, selected_affordable_position,
+ selected_joint_base, selected_joint_direction, selected_joint_type)
+ grasp_invaffordances.append(grasp_invaffordance)
+ grasp_affordances = invaffordances2affordance(grasp_invaffordances)
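+        # invaffordance_metrics scores each candidate grasp against the predicted affordable position
+        # and joint; invaffordances2affordance turns these costs into normalized affordance scores,
+        # and the grasp with the highest affordance is the one executed (argmax below)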
+ if args.show_all_grasps:
+ for g_idx, g_grasp in enumerate(gg_grasp):
+ # env.add_grasp_visual(grasp_widths[g_idx], grasp_depths[g_idx], grasp_translations[g_idx], grasp_rotations[g_idx], affordance=(grasp_affordances[g_idx] - min(grasp_affordances)) / (max(grasp_affordances) - min(grasp_affordances)), name='grasp_{}'.format(g_idx))
+ print("added grasp", grasp_affordances[g_idx])
+ selected_grasp_idx = np.argmax(grasp_affordances)
+ selected_grasp_score = grasp_scores[selected_grasp_idx]
+ selected_grasp_width = grasp_widths[selected_grasp_idx]
+ selected_grasp_width = max(min(selected_grasp_width * 1.5, max_grasp_width * args.robot_scale), 0.0)
+ selected_grasp_depth = grasp_depths[selected_grasp_idx]
+ selected_grasp_translation = grasp_translations[selected_grasp_idx]
+ selected_grasp_rotation = grasp_rotations[selected_grasp_idx]
+ selected_grasp_affordance = grasp_affordances[selected_grasp_idx]
+ # env.add_grasp_visual(selected_grasp_width, selected_grasp_depth, selected_grasp_translation, selected_grasp_rotation, affordance=selected_grasp_affordance, name='selected_grasp')
+ selected_grasp_pose = np.identity(4)
+ selected_grasp_pose[:3, :3] = selected_grasp_rotation
+ selected_grasp_pose[:3, 3] = selected_grasp_translation
+ for _ in range(100):
+ env.step()
+ env.render()
+ if args.video:
+ frame, _ = cam.get_observation()
+ frame = (frame * 255).astype(np.uint8)
+ frame = Image.fromarray(frame)
+ video.append(frame)
+ end_time = time.time()
+ logger.info(f"===> {config_id} added visuals {end_time - start_time}")
+
+ # load robot
+ if args.grasp:
+ start_time = time.time()
+ robot_material = env.get_material(4, 4, 0.01)
+ robot_move_steps_per_unit = 1000
+ robot_short_wait_steps = 10
+ robot_long_wait_steps = 1000
+ robot = Robot(env, args.robot_urdf_path, robot_material, open_gripper=False, scale=args.robot_scale)
+ env.end_checking_contact(robot.hand_actor_id, robot.gripper_actor_ids, False)
+ selected_grasp_pre_pose = selected_grasp_pose.copy()
+ selected_grasp_pre_pose[:3, 3] -= 0.1 * selected_grasp_pre_pose[:3, 0]
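+            # pre-grasp pose: back the gripper off 10 cm along its approach axis (the first column of
+            # the grasp rotation) before closing in on the grasp pose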
+ robot.set_gripper(selected_grasp_width)
+ robot.set_pose(selected_grasp_pre_pose, selected_grasp_depth)
+ # env.add_point_visual(selected_grasp_pre_pose[:3, 3], color=[1, 0, 0], radius=0.01, name='selected_grasp_pre_translation')
+ for _ in range(100):
+ env.step()
+ env.render()
+ if args.video:
+ frame, _ = cam.get_observation()
+ frame = (frame * 255).astype(np.uint8)
+ frame = Image.fromarray(frame)
+ video.append(frame)
+ read_width = robot.get_gripper()
+ read_pose = robot.get_pose(selected_grasp_depth)
+ pose_error = calc_pose_error(selected_grasp_pre_pose, read_pose)
+ end_time = time.time()
+ logger.info(f"===> {config_id} loaded robot {end_time - start_time} {read_width - selected_grasp_width} {pose_error}")
+
+ # grasp
+ start_time = time.time()
+ frames = robot.move_to_target_pose(selected_grasp_pose, selected_grasp_depth, robot_move_steps_per_unit * 10, vis_gif=args.video, cam=cam)
+ if args.video:
+ video.extend(frames)
+ frames = robot.wait_n_steps(robot_short_wait_steps, vis_gif=args.video, cam=cam)
+ if args.video:
+ video.extend(frames)
+ robot.close_gripper()
+ frames = robot.wait_n_steps(robot_long_wait_steps, vis_gif=args.video, cam=cam)
+ if args.video:
+ video.extend(frames)
+ current_width = robot.get_gripper()
+ success_grasp = env.check_contact_right()
+ read_pose = robot.get_pose(selected_grasp_depth)
+ pose_error = calc_pose_error(selected_grasp_pose, read_pose)
+ # env.add_point_visual(selected_grasp_pose[:3, 3], color=[0, 0, 1], radius=0.01, name='selected_grasp_translation')
+ # env.render()
+ # if args.video:
+ # frame, _ = cam.get_observation()
+ # frame = (frame * 255).astype(np.uint8)
+ # frame = Image.fromarray(frame)
+ # video.append(frame)
+ end_time = time.time()
+ logger.info(f"===> {config_id} grasped {end_time - start_time} {pose_error} {success_grasp}")
+
+ # manipulation
+ if task_state != 0:
+ plan_trajectory, real_trajectory, semi_trajectory = [], [], []
+ plan_current_pose = robot.get_pose(selected_grasp_depth)
+ joint_state_initial = env.get_target_part_state()
+ if selected_joint_type == 0:
+ joint_state_initial = joint_state_initial / np.pi * 180
+ elif selected_joint_type == 1:
+ joint_state_initial = joint_state_initial * 100
+ else:
+ raise ValueError(f"invalid joint type {selected_joint_type}")
+ joint_state_task = joint_state_initial - task_state * (1 - args.success_threshold)
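+                # success criterion: the joint must move by at least (1 - success_threshold) of the
+                # commanded task_state; the sign of the remaining error is checked below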
+ start_time = time.time()
+ for step in tqdm.trange(int(np.ceil(np.abs(task_state / 2.0) * 1.5))):
+ current_pose = robot.get_pose(selected_grasp_depth)
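+                    # each step commands a small delta along the predicted joint: a 2-degree rotation
+                    # about the joint axis for revolute joints, or a 2 cm translation along it for
+                    # prismatic joints, signed against task_state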
+ if selected_joint_type == 0:
+ rotation_angle = -2.0 * np.sign(task_state) * selected_joint_re / 180.0 * np.pi
+ delta_pose = tf.rotation_matrix(angle=rotation_angle, direction=selected_joint_direction, point=selected_joint_base)
+ elif selected_joint_type == 1:
+ translation_distance = -2.0 * np.sign(task_state) / 100.0
+ delta_pose = tf.translation_matrix(selected_joint_direction * translation_distance)
+ else:
+ raise ValueError(f"invalid joint type {selected_joint_type}")
+ # next_pose = delta_pose @ plan_current_pose
+ next_pose = delta_pose @ current_pose
+ frames = robot.move_to_target_pose(next_pose, selected_grasp_depth, robot_move_steps_per_unit * 2, vis_gif=args.video, cam=cam)
+ if args.video:
+ video.extend(frames)
+ # robot.wait_n_steps(robot_short_wait_steps)
+ read_pose = robot.get_pose(selected_grasp_depth)
+ pose_error = calc_pose_error(next_pose, read_pose)
+ current_joint_state = env.get_target_part_state()
+ if selected_joint_type == 0:
+ current_joint_state = current_joint_state / np.pi * 180
+ elif selected_joint_type == 1:
+ current_joint_state = current_joint_state * 100
+ else:
+ raise ValueError(f"invalid joint type {selected_joint_type}")
+ joint_state_error = current_joint_state - joint_state_task
+ logger.info(f"{pose_error} {joint_state_error}")
+ real_trajectory.append(read_pose)
+ semi_trajectory.append(next_pose)
+ plan_next_pose = delta_pose @ plan_current_pose
+ plan_current_pose = plan_next_pose.copy()
+ plan_trajectory.append(plan_next_pose)
+ success_manipulation = ((task_state < 0) and (joint_state_error > 0)) or ((task_state > 0) and (joint_state_error < 0))
+ if success_manipulation:
+ break
+ # for step in tqdm.trange(int(np.ceil(np.abs(task_state / 2.0)))):
+ # if step % 2 != 0:
+ # continue
+ # env.add_grasp_visual(current_width, selected_grasp_depth, plan_trajectory[step][:3, 3], plan_trajectory[step][:3, :3], affordance=1, name='grasp_{}'.format(-step))
+ # env.add_grasp_visual(current_width, selected_grasp_depth, real_trajectory[step][:3, 3], real_trajectory[step][:3, :3], affordance=0, name='grasp_{}_real'.format(-step))
+ # env.add_grasp_visual(current_width, selected_grasp_depth, semi_trajectory[step][:3, 3], semi_trajectory[step][:3, :3], affordance=0.99, name='grasp_{}_semi'.format(-step))
+ # env.render()
+ # if args.video:
+ # frame, _ = cam.get_observation()
+ # frame = (frame * 255).astype(np.uint8)
+ # frame = Image.fromarray(frame)
+ # video.append(frame)
+ end_time = time.time()
+ logger.info(f"===> {config_id} manipulated {end_time - start_time} {success_manipulation}")
+ if success_manipulation:
+ success_configs.append(config_id)
+ else:
+ fail_configs.append(config_id)
+ env.scene.remove_articulation(robot.robot)
+ env.scene.remove_articulation(env.object)
+
+ if args.video:
+ imageio.mimsave(os.path.join(output_path, f'video_{str(config_id).zfill(2)}_{str(success_manipulation)}.mp4'), video)
+
+ logger.info(f"===> success configs {success_configs} {len(success_configs)}")
+ logger.info(f"===> fail configs {fail_configs} {len(fail_configs)}")
+ logger.info(f"===> success rate {len(success_configs) / (len(success_configs) + len(fail_configs))}")
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..3e606d0
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,331 @@
+from typing import Tuple, List, Optional
+from omegaconf import OmegaConf
+import os
+import tqdm
+import itertools
+from itertools import combinations
+import random
+import numpy as np
+import torch
+import torch.nn as nn
+import cupy as cp
+# import cudf
+# import cuml
+from sklearn.cluster import KMeans
+# from cuml.cluster import DBSCAN
+# from cuml.common.device_selection import using_device_type
+import MinkowskiEngine as ME
+import open3d as o3d
+
+from models.roartnet import create_shot_encoder, create_encoder
+from models.voting import ppf_kernel, rot_voting_kernel, ppf4d_kernel
+from utilities.metrics_utils import calc_translation_error, calc_direction_error
+from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting, visualize_confidence_voting
+from utilities.data_utils import pc_normalize, farthest_point_sample, fibonacci_sphere
+from utilities.env_utils import setup_seed
+from utilities.constants import seed, light_blue_color, red_color, dark_red_color, dark_green_color, yellow_color
+from src_shot.build import shot
+
+
+def voting_translation(pc:np.ndarray, tr_offsets:np.ndarray, point_idxs:np.ndarray, confs:np.ndarray,
+ resolution:float, voting_num:int, device:int,
+ translation2pc:bool, multi_candidate:bool, candidate_threshold:float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ # pc: (N, 3), tr_offsets: (N_t, 2), point_idxs: (N_t, 2), confs: (N_t,)
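+    # translation voting: build a voxel grid covering the point cloud padded by its own extent on each
+    # side, let every sampled point pair cast confidence-weighted votes through the ppf_kernel CUDA
+    # kernel, then take either the argmax cell or (multi_candidate) a weighted mean of all cells whose
+    # normalized votes exceed candidate_threshold; translation2pc optionally snaps the result to the
+    # closest input point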
+ block_size = (tr_offsets.shape[0] + 512 - 1) // 512
+ pc_min = np.min(pc, 0)
+ pc_max = np.max(pc, 0)
+ corner_min = pc_min - (pc_max - pc_min)
+ corner_max = pc_max + (pc_max - pc_min)
+ corners = np.stack([corner_min, corner_max])
+ grid_res = ((corners[1] - corners[0]) / resolution).astype(np.int32) + 1
+
+ with cp.cuda.Device(device):
+ grid_obj = cp.asarray(np.zeros(grid_res, dtype=np.float32))
+
+ ppf_kernel(
+ (block_size, 1, 1),
+ (512, 1, 1),
+ (
+ cp.ascontiguousarray(cp.asarray(pc).astype(cp.float32)),
+ cp.ascontiguousarray(cp.asarray(tr_offsets).astype(cp.float32)),
+ cp.ascontiguousarray(cp.asarray(confs).astype(cp.float32)),
+ cp.ascontiguousarray(cp.asarray(point_idxs).astype(cp.int32)),
+ grid_obj,
+ cp.ascontiguousarray(cp.asarray(corners[0]).astype(cp.float32)),
+ cp.float32(resolution),
+ cp.int32(tr_offsets.shape[0]),
+ cp.int32(voting_num),
+ cp.int32(grid_obj.shape[0]),
+ cp.int32(grid_obj.shape[1]),
+ cp.int32(grid_obj.shape[2])
+ )
+ )
+
+ if not multi_candidate:
+ cand = cp.array(cp.unravel_index(cp.array([cp.argmax(grid_obj, axis=None)]), grid_obj.shape)).T[::-1]
+ cand_world = cp.asarray(corners[0]) + cand * resolution
+ else:
+ indices = cp.indices(grid_obj.shape)
+ indices_list = cp.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape))
+ votes_list = grid_obj.reshape(-1)
+ grid_pc = cp.asarray(corners[0]) + indices_list * resolution
+ normalized_votes_list = votes_list / cp.max(votes_list)
+ candidates = grid_pc[normalized_votes_list >= candidate_threshold]
+ candidate_weights = normalized_votes_list[normalized_votes_list >= candidate_threshold]
+ candidate_weights = candidate_weights / cp.sum(candidate_weights)
+ cand_world = cp.sum(candidates * candidate_weights[:, None], axis=0)[None, :]
+
+ if translation2pc:
+ pc_cp = cp.asarray(pc)
+ best_idx = cp.linalg.norm(pc_cp - cand_world, axis=-1).argmin()
+ translation = pc_cp[best_idx]
+ else:
+ translation = cand_world[0]
+
+ return (translation.get(), grid_obj.get(), corners)
+
+def voting_rotation(pc:np.ndarray, rot_offsets:np.ndarray, point_idxs:np.ndarray, confs:np.ndarray,
+ rot_candidate_num:int, angle_tol:float, voting_num:int, bmm_size:int, device:int,
+ multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool, kmeans:Optional[KMeans],
+ rotation_multi_neighbor:bool, neighbor_threshold:float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ # pc: (N, 3), rot_offsets: (N_t,), point_idxs: (N_t, 2), confs: (N_t,)
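+    # rotation voting: rot_voting_kernel expands each tuple's predicted angle into voting_num candidate
+    # directions, which are compared against a Fibonacci sphere of rot_candidate_num directions; votes
+    # are confidence-weighted and either hard-thresholded within 2 * angle_tol or cosine-weighted within
+    # neighbor_threshold. The densest direction is returned, or (multi_candidate) a weighted average of
+    # near-maximum candidates, optionally KMeans-clustered so that distinct modes are not averaged together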
+ block_size = (rot_offsets.shape[0] + 512 - 1) // 512
+ sphere_pts = np.array(fibonacci_sphere(rot_candidate_num))
+ expanded_confs = confs[:, None].repeat(voting_num, axis=-1).reshape(-1, 1)
+
+ with cp.cuda.Device(device):
+ candidates = cp.zeros((rot_offsets.shape[0], voting_num, 3), cp.float32)
+
+ rot_voting_kernel(
+ (block_size, 1, 1),
+ (512, 1, 1),
+ (
+ cp.ascontiguousarray(cp.asarray(pc).astype(cp.float32)),
+ cp.ascontiguousarray(cp.asarray(rot_offsets).astype(cp.float32)),
+ candidates,
+ cp.ascontiguousarray(cp.asarray(point_idxs).astype(cp.int32)),
+ cp.int32(rot_offsets.shape[0]),
+ cp.int32(voting_num)
+ )
+ )
+
+ candidates = candidates.get().reshape(-1, 3)
+
+ with torch.no_grad():
+ candidates = torch.from_numpy(candidates).cuda(device)
+ expanded_confs = torch.from_numpy(expanded_confs).cuda(device)
+ sph_cp = torch.tensor(sphere_pts.T, dtype=torch.float32).cuda(device)
+ counts = torch.zeros((sphere_pts.shape[0],), dtype=torch.float32).cuda(device) # (rot_candidate_num,)
+
+ for i in range((candidates.shape[0] - 1) // bmm_size + 1):
+ cos = candidates[i * bmm_size:(i + 1) * bmm_size].mm(sph_cp) # (bmm_size, rot_candidate_num)
+ if not rotation_multi_neighbor:
+ voting = (cos > np.cos(2 * angle_tol / 180 * np.pi)).float() # (bmm_size, rot_candidate_num)
+ else:
+ # voting_indices = torch.topk(cos, neighbors_num, dim=-1)[1]
+ # voting_mask = torch.zeros_like(cos)
+ # voting_mask.scatter_(1, voting_indices, 1)
+ voting_mask = (cos > np.cos(neighbor_threshold / 180 * np.pi)).float()
+ voting = cos * voting_mask # (bmm_size, rot_candidate_num)
+ counts += torch.sum(voting * expanded_confs[i * bmm_size:(i + 1) * bmm_size], dim=0)
+
+ if not multi_candidate:
+ direction = np.array(sphere_pts[np.argmax(counts.cpu().numpy())])
+ else:
+ counts_list = counts.cpu().numpy()
+ normalized_counts_list = counts_list / np.max(counts_list)
+ candidates = sphere_pts[normalized_counts_list >= candidate_threshold]
+ candidate_weights = normalized_counts_list[normalized_counts_list >= candidate_threshold]
+ candidate_weights = candidate_weights / np.sum(candidate_weights)
+ if not rotation_cluster:
+ direction = np.sum(candidates * candidate_weights[:, None], axis=0)
+ direction /= np.linalg.norm(direction)
+ else:
+ if candidates.shape[0] == 1:
+ direction = candidates[0]
+ else:
+ kmeans.fit(candidates)
+ candidate_center1 = kmeans.cluster_centers_[0]
+ candidate_center2 = kmeans.cluster_centers_[1]
+ cluster_cos_theta = np.dot(candidate_center1, candidate_center2)
+ cluster_cos_theta = np.clip(cluster_cos_theta, -1., 1.)
+ cluster_theta = np.arccos(cluster_cos_theta)
+ if cluster_theta > np.pi/2:
+ candidate_clusters = kmeans.labels_
+ clusters_num = np.bincount(candidate_clusters)
+ if clusters_num[0] == clusters_num[1]:
+ candidate_weights1 = candidate_weights[candidate_clusters == 0]
+ candidate_weights2 = candidate_weights[candidate_clusters == 1]
+ if np.sum(candidate_weights1) >= np.sum(candidate_weights2):
+ candidates = candidates[candidate_clusters == 0]
+ candidate_weights = candidate_weights[candidate_clusters == 0]
+ candidate_weights = candidate_weights / np.sum(candidate_weights)
+ direction = np.sum(candidates * candidate_weights[:, None], axis=0)
+ direction /= np.linalg.norm(direction)
+ else:
+ candidates = candidates[candidate_clusters == 1]
+ candidate_weights = candidate_weights[candidate_clusters == 1]
+ candidate_weights = candidate_weights / np.sum(candidate_weights)
+ direction = np.sum(candidates * candidate_weights[:, None], axis=0)
+ direction /= np.linalg.norm(direction)
+ else:
+ max_cluster = np.bincount(candidate_clusters).argmax()
+ candidates = candidates[candidate_clusters == max_cluster]
+ candidate_weights = candidate_weights[candidate_clusters == max_cluster]
+ candidate_weights = candidate_weights / np.sum(candidate_weights)
+ direction = np.sum(candidates * candidate_weights[:, None], axis=0)
+ direction /= np.linalg.norm(direction)
+ else:
+ direction = np.sum(candidates * candidate_weights[:, None], axis=0)
+ direction /= np.linalg.norm(direction)
+
+ return (direction, sphere_pts, counts.cpu().numpy())
+
+
+def inference_fn(pc:np.ndarray, pc_color:Optional[np.ndarray], shot_encoder:nn.Module, encoder:nn.Module,
+ denoise:bool, normalize:str, resolution:float, receptive_field:int, sample_points_num:int, sample_tuples_num:int, tuple_more_num:int,
+ voting_num:int, rot_bin_num:int, angle_tol:float,
+ translation2pc:bool, multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool,
+ rotation_multi_neighbor:bool, neighbor_threshold:float, bmm_size:int, joint_num:int, device:int) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+    if not hasattr(inference_fn, 'permutations') or inference_fn.sample_points_num != sample_points_num:
+        inference_fn.permutations = list(itertools.permutations(range(sample_points_num), 2))
+        inference_fn.sample_points_num = sample_points_num
+ if rotation_cluster:
+ # kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto')
+ kmeans = KMeans(n_clusters=2, init='k-means++', n_init=1)
+ else:
+ kmeans = None
+ rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi))
+ has_rgb = pc_color is not None
+
+ # preprocess point cloud
+ if denoise:
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pc)
+ _, index = pcd.remove_statistical_outlier(nb_neighbors=100, std_ratio=1.5)
+ pc = pc[index]
+ if has_rgb:
+ pc_color = pc_color[index]
+
+ pc, center, scale = pc_normalize(pc, normalize)
+
+ indices = ME.utils.sparse_quantize(np.ascontiguousarray(pc), return_index=True, quantization_size=resolution)[1]
+ pc = np.ascontiguousarray(pc[indices].astype(np.float32))
+ if has_rgb:
+ pc_color = pc_color[indices]
+
+ pc_normal = shot.estimate_normal(pc, resolution * receptive_field).reshape(-1, 3).astype(np.float32)
+ pc_normal[~np.isfinite(pc_normal)] = 0
+
+ pc_shot = shot.compute(pc, resolution * receptive_field, resolution * receptive_field).reshape(-1, 352).astype(np.float32)
+ pc_shot[~np.isfinite(pc_shot)] = 0
+
+ pc, indices = farthest_point_sample(pc, sample_points_num)
+ pc_normal = pc_normal[indices]
+ pc_shot = pc_shot[indices]
+ if has_rgb:
+ pc_color = pc_color[indices]
+
+ point_idxs = random.sample(inference_fn.permutations, sample_tuples_num)
+ point_idxs = np.array(point_idxs, dtype=np.int64)
+ point_idxs_more = np.random.randint(0, sample_points_num, size=(sample_tuples_num, tuple_more_num), dtype=np.int64)
+ point_idxs_all = np.concatenate([point_idxs, point_idxs_more], axis=-1)
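+    # each tuple is an ordered point pair plus tuple_more_num extra random points; the extra points only
+    # enrich the network inputs, while the voting below uses just the pair indices (point_idxs)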
+
+ pcs = torch.from_numpy(pc)[None, ...].cuda(device)
+ pc_normals = torch.from_numpy(pc_normal)[None, ...].cuda(device)
+ pc_shots = torch.from_numpy(pc_shot)[None, ...].cuda(device)
+ if has_rgb:
+ pc_colors = torch.from_numpy(pc_color)[None, ...].cuda(device)
+ point_idxs_alls = torch.from_numpy(point_idxs_all)[None, ...].cuda(device)
+
+ # inference
+ with torch.no_grad():
+ shot_feat = shot_encoder(pc_shots) # (1, N, N_s)
+
+ shot_inputs = torch.cat([
+ torch.gather(shot_feat, 1,
+ point_idxs_alls[:, :, i:i+1].expand(
+ (1, sample_tuples_num, shot_feat.shape[-1])))
+ for i in range(point_idxs_alls.shape[-1])], dim=-1) # (1, N_t, N_s * (2 + N_m))
+ normal_inputs = torch.cat([torch.max(
+ torch.sum(torch.gather(pc_normals, 1,
+ point_idxs_alls[:, :, i:i+1].expand(
+ (1, sample_tuples_num, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_alls[:, :, j:j+1].expand(
+ (1, sample_tuples_num, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True),
+ torch.sum(-torch.gather(pc_normals, 1,
+ point_idxs_alls[:, :, i:i+1].expand(
+ (1, sample_tuples_num, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_alls[:, :, j:j+1].expand(
+ (1, sample_tuples_num, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True))
+ for (i, j) in combinations(np.arange(point_idxs_alls.shape[-1]), 2)], dim=-1) # (1, N_t, (2+N_m \choose 2))
+ coord_inputs = torch.cat([
+ torch.gather(pcs, 1,
+ point_idxs_alls[:, :, i:i+1].expand(
+ (1, sample_tuples_num, pcs.shape[-1]))) -
+ torch.gather(pcs, 1,
+ point_idxs_alls[:, :, j:j+1].expand(
+ (1, sample_tuples_num, pcs.shape[-1])))
+ for (i, j) in combinations(np.arange(point_idxs_alls.shape[-1]), 2)], dim=-1) # (1, N_t, 3 * (2+N_m \choose 2))
+ if has_rgb:
+ rgb_inputs = torch.cat([
+ torch.gather(pc_colors, 1,
+ point_idxs_alls[:, :, i:i+1].expand(
+ (1, sample_tuples_num, pc_colors.shape[-1])))
+ for i in range(point_idxs_alls.shape[-1])], dim=-1) # (1, N_t, 3 * (2 + N_m))
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1)
+ else:
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1)
+ preds = encoder(inputs) # (1, N_t, (2 + N_r + 2 + 1) * J)
+ pred = preds.cpu().numpy().astype(np.float32)[0] # (N_t, (2 + N_r + 2 + 1) * J)
+ pred_tensor = torch.from_numpy(pred)
+
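+    # column layout of pred for joint j (matching the slices below): [0, 2J) translation offsets,
+    # [2J, 2J + N_r*J) rotation-angle bins, [2J + N_r*J, 2J + N_r*J + 2J) affordance offsets,
+    # and the last J columns are per-joint confidence logits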
+ pred_translations, pred_rotations, pred_affordances = [], [], []
+ for j in range(joint_num):
+ # conf selection
+ pred_conf = torch.sigmoid(pred_tensor[:, -1*joint_num+j]) # (N_t,)
+ not_selected_indices = pred_conf < 0.5
+ pred_conf[not_selected_indices] = 0
+ # pred_conf[pred_conf > 0] = 1
+ pred_conf = pred_conf.numpy()
+
+ # translation voting
+ pred_tr = pred[:, 2*j:2*(j+1)] # (N_t, 2)
+ pred_translation, grid_obj, corners = voting_translation(pc, pred_tr, point_idxs, pred_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_translations.append(pred_translation)
+
+ # rotation voting
+ pred_rot = pred_tensor[:, (2*joint_num+rot_bin_num*j):(2*joint_num+rot_bin_num*(j+1))] # (N_t, rot_bin_num)
+ pred_rot = torch.softmax(pred_rot, dim=-1)
+ pred_rot = torch.multinomial(pred_rot, 1).float()[:, 0] # (N_t,)
+ pred_rot = pred_rot / (rot_bin_num - 1) * np.pi
+ pred_rot = pred_rot.numpy()
+ pred_direction, sphere_pts, counts = voting_rotation(pc, pred_rot, point_idxs, pred_conf,
+ rot_candidate_num, angle_tol, voting_num, bmm_size, device,
+ multi_candidate, candidate_threshold, rotation_cluster, kmeans,
+ rotation_multi_neighbor, neighbor_threshold)
+ pred_rotations.append(pred_direction)
+
+ # affordance voting
+ pred_afford = pred[:, (2*joint_num+rot_bin_num*joint_num+2*j):(2*joint_num+rot_bin_num*joint_num+2*(j+1))] # (N_t, 2)
+ pred_affordance, agrid_obj, acorners = voting_translation(pc, pred_afford, point_idxs, pred_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_affordances.append(pred_affordance)
+ return (pred_translations, pred_rotations, pred_affordances)
+
+
+if __name__ == '__main__':
+ pass
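+    # hypothetical usage sketch (not executed here; names and paths are assumptions): the evaluation
+    # scripts construct the trained networks and the observed point cloud, then call inference_fn
+    # roughly as follows:
+    #   cfg = OmegaConf.load('path/to/eval_config.yaml')
+    #   shot_encoder = create_shot_encoder(...)   # weights loaded as in the training/eval code
+    #   encoder = create_encoder(...)
+    #   translations, directions, affordances = inference_fn(pc, None, shot_encoder, encoder, ...)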
diff --git a/license/.gitkeep b/license/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/models/helper_math.cuh b/models/helper_math.cuh
new file mode 100644
index 0000000..4597d8a
--- /dev/null
+++ b/models/helper_math.cuh
@@ -0,0 +1,1450 @@
+#pragma once
+/**
+* Copyright 1993-2012 NVIDIA Corporation. All rights reserved.
+*
+* Please refer to the NVIDIA end user license agreement (EULA) associated
+* with this source code for terms and conditions that govern your use of
+* this software. Any use, reproduction, disclosure, or distribution of
+* this software and related documentation outside the terms of the EULA
+* is strictly prohibited.
+*
+*/
+
+/*
+* This file implements common mathematical operations on vector types
+* (float3, float4 etc.) since these are not provided as standard by CUDA.
+*
+* The syntax is modeled on the Cg standard library.
+*
+* This is part of the Helper library includes
+*
+* Thanks to Linh Hah for additions and fixes.
+*/
+
+#ifndef HELPER_MATH_H
+#define HELPER_MATH_H
+
+#include "/usr/local/cuda/include/cuda_runtime.h"
+
+typedef unsigned int uint;
+typedef unsigned short ushort;
+
+#ifndef __CUDACC__
+#include <math.h>
+
+////////////////////////////////////////////////////////////////////////////////
+// host implementations of CUDA functions
+////////////////////////////////////////////////////////////////////////////////
+
+inline float fminf(float a, float b)
+{
+ return a < b ? a : b;
+}
+
+inline float fmaxf(float a, float b)
+{
+ return a > b ? a : b;
+}
+
+inline int max(int a, int b)
+{
+ return a > b ? a : b;
+}
+
+inline int min(int a, int b)
+{
+ return a < b ? a : b;
+}
+
+inline float rsqrtf(float x)
+{
+ return 1.0f / sqrtf(x);
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////
+// constructors
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 make_float2(float s)
+{
+ return make_float2(s, s);
+}
+inline __host__ __device__ float2 make_float2(float3 a)
+{
+ return make_float2(a.x, a.y);
+}
+inline __host__ __device__ float2 make_float2(int2 a)
+{
+ return make_float2(float(a.x), float(a.y));
+}
+inline __host__ __device__ float2 make_float2(uint2 a)
+{
+ return make_float2(float(a.x), float(a.y));
+}
+
+inline __host__ __device__ int2 make_int2(int s)
+{
+ return make_int2(s, s);
+}
+inline __host__ __device__ int2 make_int2(int3 a)
+{
+ return make_int2(a.x, a.y);
+}
+inline __host__ __device__ int2 make_int2(uint2 a)
+{
+ return make_int2(int(a.x), int(a.y));
+}
+inline __host__ __device__ int2 make_int2(float2 a)
+{
+ return make_int2(int(a.x), int(a.y));
+}
+
+inline __host__ __device__ uint2 make_uint2(uint s)
+{
+ return make_uint2(s, s);
+}
+inline __host__ __device__ uint2 make_uint2(uint3 a)
+{
+ return make_uint2(a.x, a.y);
+}
+inline __host__ __device__ uint2 make_uint2(int2 a)
+{
+ return make_uint2(uint(a.x), uint(a.y));
+}
+
+inline __host__ __device__ float3 make_float3(float s)
+{
+ return make_float3(s, s, s);
+}
+inline __host__ __device__ float3 make_float3(float2 a)
+{
+ return make_float3(a.x, a.y, 0.0f);
+}
+inline __host__ __device__ float3 make_float3(float2 a, float s)
+{
+ return make_float3(a.x, a.y, s);
+}
+inline __host__ __device__ float3 make_float3(float4 a)
+{
+ return make_float3(a.x, a.y, a.z);
+}
+inline __host__ __device__ float3 make_float3(int3 a)
+{
+ return make_float3(float(a.x), float(a.y), float(a.z));
+}
+inline __host__ __device__ float3 make_float3(uint3 a)
+{
+ return make_float3(float(a.x), float(a.y), float(a.z));
+}
+
+inline __host__ __device__ int3 make_int3(int s)
+{
+ return make_int3(s, s, s);
+}
+inline __host__ __device__ int3 make_int3(int2 a)
+{
+ return make_int3(a.x, a.y, 0);
+}
+inline __host__ __device__ int3 make_int3(int2 a, int s)
+{
+ return make_int3(a.x, a.y, s);
+}
+inline __host__ __device__ int3 make_int3(uint3 a)
+{
+ return make_int3(int(a.x), int(a.y), int(a.z));
+}
+inline __host__ __device__ int3 make_int3(float3 a)
+{
+ return make_int3(int(a.x), int(a.y), int(a.z));
+}
+
+inline __host__ __device__ uint3 make_uint3(uint s)
+{
+ return make_uint3(s, s, s);
+}
+inline __host__ __device__ uint3 make_uint3(uint2 a)
+{
+ return make_uint3(a.x, a.y, 0);
+}
+inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
+{
+ return make_uint3(a.x, a.y, s);
+}
+inline __host__ __device__ uint3 make_uint3(uint4 a)
+{
+ return make_uint3(a.x, a.y, a.z);
+}
+inline __host__ __device__ uint3 make_uint3(int3 a)
+{
+ return make_uint3(uint(a.x), uint(a.y), uint(a.z));
+}
+
+inline __host__ __device__ float4 make_float4(float s)
+{
+ return make_float4(s, s, s, s);
+}
+inline __host__ __device__ float4 make_float4(float3 a)
+{
+ return make_float4(a.x, a.y, a.z, 0.0f);
+}
+inline __host__ __device__ float4 make_float4(float3 a, float w)
+{
+ return make_float4(a.x, a.y, a.z, w);
+}
+inline __host__ __device__ float4 make_float4(int4 a)
+{
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
+}
+inline __host__ __device__ float4 make_float4(uint4 a)
+{
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
+}
+
+inline __host__ __device__ int4 make_int4(int s)
+{
+ return make_int4(s, s, s, s);
+}
+inline __host__ __device__ int4 make_int4(int3 a)
+{
+ return make_int4(a.x, a.y, a.z, 0);
+}
+inline __host__ __device__ int4 make_int4(int3 a, int w)
+{
+ return make_int4(a.x, a.y, a.z, w);
+}
+inline __host__ __device__ int4 make_int4(uint4 a)
+{
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
+}
+inline __host__ __device__ int4 make_int4(float4 a)
+{
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
+}
+
+
+inline __host__ __device__ uint4 make_uint4(uint s)
+{
+ return make_uint4(s, s, s, s);
+}
+inline __host__ __device__ uint4 make_uint4(uint3 a)
+{
+ return make_uint4(a.x, a.y, a.z, 0);
+}
+inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
+{
+ return make_uint4(a.x, a.y, a.z, w);
+}
+inline __host__ __device__ uint4 make_uint4(int4 a)
+{
+ return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// negate
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 operator-(float2 &a)
+{
+ return make_float2(-a.x, -a.y);
+}
+inline __host__ __device__ int2 operator-(int2 &a)
+{
+ return make_int2(-a.x, -a.y);
+}
+inline __host__ __device__ float3 operator-(float3 &a)
+{
+ return make_float3(-a.x, -a.y, -a.z);
+}
+inline __host__ __device__ int3 operator-(int3 &a)
+{
+ return make_int3(-a.x, -a.y, -a.z);
+}
+inline __host__ __device__ float4 operator-(float4 &a)
+{
+ return make_float4(-a.x, -a.y, -a.z, -a.w);
+}
+inline __host__ __device__ int4 operator-(int4 &a)
+{
+ return make_int4(-a.x, -a.y, -a.z, -a.w);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// addition
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 operator+(float2 a, float2 b)
+{
+ return make_float2(a.x + b.x, a.y + b.y);
+}
+inline __host__ __device__ void operator+=(float2 &a, float2 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+}
+inline __host__ __device__ float2 operator+(float2 a, float b)
+{
+ return make_float2(a.x + b, a.y + b);
+}
+inline __host__ __device__ float2 operator+(float b, float2 a)
+{
+ return make_float2(a.x + b, a.y + b);
+}
+inline __host__ __device__ void operator+=(float2 &a, float b)
+{
+ a.x += b;
+ a.y += b;
+}
+
+inline __host__ __device__ int2 operator+(int2 a, int2 b)
+{
+ return make_int2(a.x + b.x, a.y + b.y);
+}
+inline __host__ __device__ void operator+=(int2 &a, int2 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+}
+inline __host__ __device__ int2 operator+(int2 a, int b)
+{
+ return make_int2(a.x + b, a.y + b);
+}
+inline __host__ __device__ int2 operator+(int b, int2 a)
+{
+ return make_int2(a.x + b, a.y + b);
+}
+inline __host__ __device__ void operator+=(int2 &a, int b)
+{
+ a.x += b;
+ a.y += b;
+}
+
+inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
+{
+ return make_uint2(a.x + b.x, a.y + b.y);
+}
+inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+}
+inline __host__ __device__ uint2 operator+(uint2 a, uint b)
+{
+ return make_uint2(a.x + b, a.y + b);
+}
+inline __host__ __device__ uint2 operator+(uint b, uint2 a)
+{
+ return make_uint2(a.x + b, a.y + b);
+}
+inline __host__ __device__ void operator+=(uint2 &a, uint b)
+{
+ a.x += b;
+ a.y += b;
+}
+
+
+inline __host__ __device__ float3 operator+(float3 a, float3 b)
+{
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
+}
+inline __host__ __device__ void operator+=(float3 &a, float3 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+ a.z += b.z;
+}
+inline __host__ __device__ float3 operator+(float3 a, float b)
+{
+ return make_float3(a.x + b, a.y + b, a.z + b);
+}
+inline __host__ __device__ void operator+=(float3 &a, float b)
+{
+ a.x += b;
+ a.y += b;
+ a.z += b;
+}
+
+inline __host__ __device__ int3 operator+(int3 a, int3 b)
+{
+ return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
+}
+inline __host__ __device__ void operator+=(int3 &a, int3 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+ a.z += b.z;
+}
+inline __host__ __device__ int3 operator+(int3 a, int b)
+{
+ return make_int3(a.x + b, a.y + b, a.z + b);
+}
+inline __host__ __device__ void operator+=(int3 &a, int b)
+{
+ a.x += b;
+ a.y += b;
+ a.z += b;
+}
+
+inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
+{
+ return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
+}
+inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+ a.z += b.z;
+}
+inline __host__ __device__ uint3 operator+(uint3 a, uint b)
+{
+ return make_uint3(a.x + b, a.y + b, a.z + b);
+}
+inline __host__ __device__ void operator+=(uint3 &a, uint b)
+{
+ a.x += b;
+ a.y += b;
+ a.z += b;
+}
+
+inline __host__ __device__ int3 operator+(int b, int3 a)
+{
+ return make_int3(a.x + b, a.y + b, a.z + b);
+}
+inline __host__ __device__ uint3 operator+(uint b, uint3 a)
+{
+ return make_uint3(a.x + b, a.y + b, a.z + b);
+}
+inline __host__ __device__ float3 operator+(float b, float3 a)
+{
+ return make_float3(a.x + b, a.y + b, a.z + b);
+}
+
+inline __host__ __device__ float4 operator+(float4 a, float4 b)
+{
+ return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+inline __host__ __device__ void operator+=(float4 &a, float4 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+ a.z += b.z;
+ a.w += b.w;
+}
+inline __host__ __device__ float4 operator+(float4 a, float b)
+{
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
+}
+inline __host__ __device__ float4 operator+(float b, float4 a)
+{
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
+}
+inline __host__ __device__ void operator+=(float4 &a, float b)
+{
+ a.x += b;
+ a.y += b;
+ a.z += b;
+ a.w += b;
+}
+
+inline __host__ __device__ int4 operator+(int4 a, int4 b)
+{
+ return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+inline __host__ __device__ void operator+=(int4 &a, int4 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+ a.z += b.z;
+ a.w += b.w;
+}
+inline __host__ __device__ int4 operator+(int4 a, int b)
+{
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
+}
+inline __host__ __device__ int4 operator+(int b, int4 a)
+{
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
+}
+inline __host__ __device__ void operator+=(int4 &a, int b)
+{
+ a.x += b;
+ a.y += b;
+ a.z += b;
+ a.w += b;
+}
+
+inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
+{
+ return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+}
+inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
+{
+ a.x += b.x;
+ a.y += b.y;
+ a.z += b.z;
+ a.w += b.w;
+}
+inline __host__ __device__ uint4 operator+(uint4 a, uint b)
+{
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
+}
+inline __host__ __device__ uint4 operator+(uint b, uint4 a)
+{
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
+}
+inline __host__ __device__ void operator+=(uint4 &a, uint b)
+{
+ a.x += b;
+ a.y += b;
+ a.z += b;
+ a.w += b;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// subtract
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 operator-(float2 a, float2 b)
+{
+ return make_float2(a.x - b.x, a.y - b.y);
+}
+inline __host__ __device__ void operator-=(float2 &a, float2 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+}
+inline __host__ __device__ float2 operator-(float2 a, float b)
+{
+ return make_float2(a.x - b, a.y - b);
+}
+inline __host__ __device__ float2 operator-(float b, float2 a)
+{
+ return make_float2(b - a.x, b - a.y);
+}
+inline __host__ __device__ void operator-=(float2 &a, float b)
+{
+ a.x -= b;
+ a.y -= b;
+}
+
+inline __host__ __device__ int2 operator-(int2 a, int2 b)
+{
+ return make_int2(a.x - b.x, a.y - b.y);
+}
+inline __host__ __device__ void operator-=(int2 &a, int2 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+}
+inline __host__ __device__ int2 operator-(int2 a, int b)
+{
+ return make_int2(a.x - b, a.y - b);
+}
+inline __host__ __device__ int2 operator-(int b, int2 a)
+{
+ return make_int2(b - a.x, b - a.y);
+}
+inline __host__ __device__ void operator-=(int2 &a, int b)
+{
+ a.x -= b;
+ a.y -= b;
+}
+
+inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
+{
+ return make_uint2(a.x - b.x, a.y - b.y);
+}
+inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+}
+inline __host__ __device__ uint2 operator-(uint2 a, uint b)
+{
+ return make_uint2(a.x - b, a.y - b);
+}
+inline __host__ __device__ uint2 operator-(uint b, uint2 a)
+{
+ return make_uint2(b - a.x, b - a.y);
+}
+inline __host__ __device__ void operator-=(uint2 &a, uint b)
+{
+ a.x -= b;
+ a.y -= b;
+}
+
+inline __host__ __device__ float3 operator-(float3 a, float3 b)
+{
+ return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
+}
+inline __host__ __device__ void operator-=(float3 &a, float3 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+ a.z -= b.z;
+}
+inline __host__ __device__ float3 operator-(float3 a, float b)
+{
+ return make_float3(a.x - b, a.y - b, a.z - b);
+}
+inline __host__ __device__ float3 operator-(float b, float3 a)
+{
+ return make_float3(b - a.x, b - a.y, b - a.z);
+}
+inline __host__ __device__ void operator-=(float3 &a, float b)
+{
+ a.x -= b;
+ a.y -= b;
+ a.z -= b;
+}
+
+inline __host__ __device__ int3 operator-(int3 a, int3 b)
+{
+ return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
+}
+inline __host__ __device__ void operator-=(int3 &a, int3 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+ a.z -= b.z;
+}
+inline __host__ __device__ int3 operator-(int3 a, int b)
+{
+ return make_int3(a.x - b, a.y - b, a.z - b);
+}
+inline __host__ __device__ int3 operator-(int b, int3 a)
+{
+ return make_int3(b - a.x, b - a.y, b - a.z);
+}
+inline __host__ __device__ void operator-=(int3 &a, int b)
+{
+ a.x -= b;
+ a.y -= b;
+ a.z -= b;
+}
+
+inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
+{
+ return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
+}
+inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+ a.z -= b.z;
+}
+inline __host__ __device__ uint3 operator-(uint3 a, uint b)
+{
+ return make_uint3(a.x - b, a.y - b, a.z - b);
+}
+inline __host__ __device__ uint3 operator-(uint b, uint3 a)
+{
+ return make_uint3(b - a.x, b - a.y, b - a.z);
+}
+inline __host__ __device__ void operator-=(uint3 &a, uint b)
+{
+ a.x -= b;
+ a.y -= b;
+ a.z -= b;
+}
+
+inline __host__ __device__ float4 operator-(float4 a, float4 b)
+{
+ return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
+}
+inline __host__ __device__ void operator-=(float4 &a, float4 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+ a.z -= b.z;
+ a.w -= b.w;
+}
+inline __host__ __device__ float4 operator-(float4 a, float b)
+{
+ return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
+}
+inline __host__ __device__ void operator-=(float4 &a, float b)
+{
+ a.x -= b;
+ a.y -= b;
+ a.z -= b;
+ a.w -= b;
+}
+
+inline __host__ __device__ int4 operator-(int4 a, int4 b)
+{
+ return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
+}
+inline __host__ __device__ void operator-=(int4 &a, int4 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+ a.z -= b.z;
+ a.w -= b.w;
+}
+inline __host__ __device__ int4 operator-(int4 a, int b)
+{
+ return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
+}
+inline __host__ __device__ int4 operator-(int b, int4 a)
+{
+ return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
+}
+inline __host__ __device__ void operator-=(int4 &a, int b)
+{
+ a.x -= b;
+ a.y -= b;
+ a.z -= b;
+ a.w -= b;
+}
+
+inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
+{
+ return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
+}
+inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
+{
+ a.x -= b.x;
+ a.y -= b.y;
+ a.z -= b.z;
+ a.w -= b.w;
+}
+inline __host__ __device__ uint4 operator-(uint4 a, uint b)
+{
+ return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
+}
+inline __host__ __device__ uint4 operator-(uint b, uint4 a)
+{
+ return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
+}
+inline __host__ __device__ void operator-=(uint4 &a, uint b)
+{
+ a.x -= b;
+ a.y -= b;
+ a.z -= b;
+ a.w -= b;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// multiply
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 operator*(float2 a, float2 b)
+{
+ return make_float2(a.x * b.x, a.y * b.y);
+}
+inline __host__ __device__ void operator*=(float2 &a, float2 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+}
+inline __host__ __device__ float2 operator*(float2 a, float b)
+{
+ return make_float2(a.x * b, a.y * b);
+}
+inline __host__ __device__ float2 operator*(float b, float2 a)
+{
+ return make_float2(b * a.x, b * a.y);
+}
+inline __host__ __device__ void operator*=(float2 &a, float b)
+{
+ a.x *= b;
+ a.y *= b;
+}
+
+inline __host__ __device__ int2 operator*(int2 a, int2 b)
+{
+ return make_int2(a.x * b.x, a.y * b.y);
+}
+inline __host__ __device__ void operator*=(int2 &a, int2 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+}
+inline __host__ __device__ int2 operator*(int2 a, int b)
+{
+ return make_int2(a.x * b, a.y * b);
+}
+inline __host__ __device__ int2 operator*(int b, int2 a)
+{
+ return make_int2(b * a.x, b * a.y);
+}
+inline __host__ __device__ void operator*=(int2 &a, int b)
+{
+ a.x *= b;
+ a.y *= b;
+}
+
+inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
+{
+ return make_uint2(a.x * b.x, a.y * b.y);
+}
+inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+}
+inline __host__ __device__ uint2 operator*(uint2 a, uint b)
+{
+ return make_uint2(a.x * b, a.y * b);
+}
+inline __host__ __device__ uint2 operator*(uint b, uint2 a)
+{
+ return make_uint2(b * a.x, b * a.y);
+}
+inline __host__ __device__ void operator*=(uint2 &a, uint b)
+{
+ a.x *= b;
+ a.y *= b;
+}
+
+inline __host__ __device__ float3 operator*(float3 a, float3 b)
+{
+ return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
+}
+inline __host__ __device__ void operator*=(float3 &a, float3 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+ a.z *= b.z;
+}
+inline __host__ __device__ float3 operator*(float3 a, float b)
+{
+ return make_float3(a.x * b, a.y * b, a.z * b);
+}
+inline __host__ __device__ float3 operator*(float b, float3 a)
+{
+ return make_float3(b * a.x, b * a.y, b * a.z);
+}
+inline __host__ __device__ void operator*=(float3 &a, float b)
+{
+ a.x *= b;
+ a.y *= b;
+ a.z *= b;
+}
+
+inline __host__ __device__ int3 operator*(int3 a, int3 b)
+{
+ return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
+}
+inline __host__ __device__ void operator*=(int3 &a, int3 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+ a.z *= b.z;
+}
+inline __host__ __device__ int3 operator*(int3 a, int b)
+{
+ return make_int3(a.x * b, a.y * b, a.z * b);
+}
+inline __host__ __device__ int3 operator*(int b, int3 a)
+{
+ return make_int3(b * a.x, b * a.y, b * a.z);
+}
+inline __host__ __device__ void operator*=(int3 &a, int b)
+{
+ a.x *= b;
+ a.y *= b;
+ a.z *= b;
+}
+
+inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
+{
+ return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
+}
+inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+ a.z *= b.z;
+}
+inline __host__ __device__ uint3 operator*(uint3 a, uint b)
+{
+ return make_uint3(a.x * b, a.y * b, a.z * b);
+}
+inline __host__ __device__ uint3 operator*(uint b, uint3 a)
+{
+ return make_uint3(b * a.x, b * a.y, b * a.z);
+}
+inline __host__ __device__ void operator*=(uint3 &a, uint b)
+{
+ a.x *= b;
+ a.y *= b;
+ a.z *= b;
+}
+
+inline __host__ __device__ float4 operator*(float4 a, float4 b)
+{
+ return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+inline __host__ __device__ void operator*=(float4 &a, float4 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+ a.z *= b.z;
+ a.w *= b.w;
+}
+inline __host__ __device__ float4 operator*(float4 a, float b)
+{
+ return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
+}
+inline __host__ __device__ float4 operator*(float b, float4 a)
+{
+ return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
+}
+inline __host__ __device__ void operator*=(float4 &a, float b)
+{
+ a.x *= b;
+ a.y *= b;
+ a.z *= b;
+ a.w *= b;
+}
+
+inline __host__ __device__ int4 operator*(int4 a, int4 b)
+{
+ return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+inline __host__ __device__ void operator*=(int4 &a, int4 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+ a.z *= b.z;
+ a.w *= b.w;
+}
+inline __host__ __device__ int4 operator*(int4 a, int b)
+{
+ return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
+}
+inline __host__ __device__ int4 operator*(int b, int4 a)
+{
+ return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
+}
+inline __host__ __device__ void operator*=(int4 &a, int b)
+{
+ a.x *= b;
+ a.y *= b;
+ a.z *= b;
+ a.w *= b;
+}
+
+inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
+{
+ return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
+{
+ a.x *= b.x;
+ a.y *= b.y;
+ a.z *= b.z;
+ a.w *= b.w;
+}
+inline __host__ __device__ uint4 operator*(uint4 a, uint b)
+{
+ return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
+}
+inline __host__ __device__ uint4 operator*(uint b, uint4 a)
+{
+ return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
+}
+inline __host__ __device__ void operator*=(uint4 &a, uint b)
+{
+ a.x *= b;
+ a.y *= b;
+ a.z *= b;
+ a.w *= b;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// divide
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 operator/(float2 a, float2 b)
+{
+ return make_float2(a.x / b.x, a.y / b.y);
+}
+inline __host__ __device__ void operator/=(float2 &a, float2 b)
+{
+ a.x /= b.x;
+ a.y /= b.y;
+}
+inline __host__ __device__ float2 operator/(float2 a, float b)
+{
+ return make_float2(a.x / b, a.y / b);
+}
+inline __host__ __device__ void operator/=(float2 &a, float b)
+{
+ a.x /= b;
+ a.y /= b;
+}
+inline __host__ __device__ float2 operator/(float b, float2 a)
+{
+ return make_float2(b / a.x, b / a.y);
+}
+
+inline __host__ __device__ float3 operator/(float3 a, float3 b)
+{
+ return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
+}
+inline __host__ __device__ void operator/=(float3 &a, float3 b)
+{
+ a.x /= b.x;
+ a.y /= b.y;
+ a.z /= b.z;
+}
+inline __host__ __device__ float3 operator/(float3 a, float b)
+{
+ return make_float3(a.x / b, a.y / b, a.z / b);
+}
+inline __host__ __device__ void operator/=(float3 &a, float b)
+{
+ a.x /= b;
+ a.y /= b;
+ a.z /= b;
+}
+inline __host__ __device__ float3 operator/(float b, float3 a)
+{
+ return make_float3(b / a.x, b / a.y, b / a.z);
+}
+
+inline __host__ __device__ float4 operator/(float4 a, float4 b)
+{
+ return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
+}
+inline __host__ __device__ void operator/=(float4 &a, float4 b)
+{
+ a.x /= b.x;
+ a.y /= b.y;
+ a.z /= b.z;
+ a.w /= b.w;
+}
+inline __host__ __device__ float4 operator/(float4 a, float b)
+{
+ return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
+}
+inline __host__ __device__ void operator/=(float4 &a, float b)
+{
+ a.x /= b;
+ a.y /= b;
+ a.z /= b;
+ a.w /= b;
+}
+inline __host__ __device__ float4 operator/(float b, float4 a)
+{
+ return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// min
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 fminf(float2 a, float2 b)
+{
+ return make_float2(fminf(a.x, b.x), fminf(a.y, b.y));
+}
+inline __host__ __device__ float3 fminf(float3 a, float3 b)
+{
+ return make_float3(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z));
+}
+inline __host__ __device__ float4 fminf(float4 a, float4 b)
+{
+ return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w));
+}
+
+inline __host__ __device__ int2 min(int2 a, int2 b)
+{
+ return make_int2(min(a.x, b.x), min(a.y, b.y));
+}
+inline __host__ __device__ int3 min(int3 a, int3 b)
+{
+ return make_int3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z));
+}
+inline __host__ __device__ int4 min(int4 a, int4 b)
+{
+ return make_int4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w));
+}
+
+inline __host__ __device__ uint2 min(uint2 a, uint2 b)
+{
+ return make_uint2(min(a.x, b.x), min(a.y, b.y));
+}
+inline __host__ __device__ uint3 min(uint3 a, uint3 b)
+{
+ return make_uint3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z));
+}
+inline __host__ __device__ uint4 min(uint4 a, uint4 b)
+{
+ return make_uint4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// max
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
+{
+ return make_float2(fmaxf(a.x, b.x), fmaxf(a.y, b.y));
+}
+inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
+{
+ return make_float3(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z));
+}
+inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
+{
+ return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
+}
+
+inline __host__ __device__ int2 max(int2 a, int2 b)
+{
+ return make_int2(max(a.x, b.x), max(a.y, b.y));
+}
+inline __host__ __device__ int3 max(int3 a, int3 b)
+{
+ return make_int3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));
+}
+inline __host__ __device__ int4 max(int4 a, int4 b)
+{
+ return make_int4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w));
+}
+
+inline __host__ __device__ uint2 max(uint2 a, uint2 b)
+{
+ return make_uint2(max(a.x, b.x), max(a.y, b.y));
+}
+inline __host__ __device__ uint3 max(uint3 a, uint3 b)
+{
+ return make_uint3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));
+}
+inline __host__ __device__ uint4 max(uint4 a, uint4 b)
+{
+ return make_uint4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// lerp
+// - linear interpolation between a and b, based on value t in [0, 1] range
+////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ __host__ float lerp(float a, float b, float t)
+{
+ return a + t*(b - a);
+}
+inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
+{
+ return a + t*(b - a);
+}
+inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
+{
+ return a + t*(b - a);
+}
+inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
+{
+ return a + t*(b - a);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// clamp
+// - clamp the value v to be in the range [a, b]
+////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ __host__ float clamp(float f, float a, float b)
+{
+ return fmaxf(a, fminf(f, b));
+}
+inline __device__ __host__ int clamp(int f, int a, int b)
+{
+ return max(a, min(f, b));
+}
+inline __device__ __host__ uint clamp(uint f, uint a, uint b)
+{
+ return max(a, min(f, b));
+}
+
+inline __device__ __host__ float2 clamp(float2 v, float a, float b)
+{
+ return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
+}
+inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
+{
+ return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
+}
+inline __device__ __host__ float3 clamp(float3 v, float a, float b)
+{
+ return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
+}
+inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
+{
+ return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
+}
+inline __device__ __host__ float4 clamp(float4 v, float a, float b)
+{
+ return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
+}
+inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
+{
+ return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
+}
+
+inline __device__ __host__ int2 clamp(int2 v, int a, int b)
+{
+ return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
+}
+inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
+{
+ return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
+}
+inline __device__ __host__ int3 clamp(int3 v, int a, int b)
+{
+ return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
+}
+inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
+{
+ return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
+}
+inline __device__ __host__ int4 clamp(int4 v, int a, int b)
+{
+ return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
+}
+inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
+{
+ return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
+}
+
+inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
+{
+ return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
+}
+inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
+{
+ return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
+}
+inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
+{
+ return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
+}
+inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
+{
+ return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
+}
+inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
+{
+ return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
+}
+inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
+{
+ return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// dot product
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float dot(float2 a, float2 b)
+{
+ return a.x * b.x + a.y * b.y;
+}
+inline __host__ __device__ float dot(float3 a, float3 b)
+{
+ return a.x * b.x + a.y * b.y + a.z * b.z;
+}
+inline __host__ __device__ float dot(float4 a, float4 b)
+{
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
+}
+
+inline __host__ __device__ int dot(int2 a, int2 b)
+{
+ return a.x * b.x + a.y * b.y;
+}
+inline __host__ __device__ int dot(int3 a, int3 b)
+{
+ return a.x * b.x + a.y * b.y + a.z * b.z;
+}
+inline __host__ __device__ int dot(int4 a, int4 b)
+{
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
+}
+
+inline __host__ __device__ uint dot(uint2 a, uint2 b)
+{
+ return a.x * b.x + a.y * b.y;
+}
+inline __host__ __device__ uint dot(uint3 a, uint3 b)
+{
+ return a.x * b.x + a.y * b.y + a.z * b.z;
+}
+inline __host__ __device__ uint dot(uint4 a, uint4 b)
+{
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// length
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float length(float2 v)
+{
+ return sqrtf(dot(v, v));
+}
+inline __host__ __device__ float length(float3 v)
+{
+ return sqrtf(dot(v, v));
+}
+inline __host__ __device__ float length(float4 v)
+{
+ return sqrtf(dot(v, v));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// normalize
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 normalize(float2 v)
+{
+ float invLen = rsqrtf(dot(v, v));
+ return v * invLen;
+}
+inline __host__ __device__ float3 normalize(float3 v)
+{
+ float invLen = rsqrtf(dot(v, v));
+ return v * invLen;
+}
+inline __host__ __device__ float4 normalize(float4 v)
+{
+ float invLen = rsqrtf(dot(v, v));
+ return v * invLen;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// floor
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 floorf(float2 v)
+{
+ return make_float2(floorf(v.x), floorf(v.y));
+}
+inline __host__ __device__ float3 floorf(float3 v)
+{
+ return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
+}
+inline __host__ __device__ float4 floorf(float4 v)
+{
+ return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// frac - returns the fractional portion of a scalar or each vector component
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float fracf(float v)
+{
+ return v - floorf(v);
+}
+inline __host__ __device__ float2 fracf(float2 v)
+{
+ return make_float2(fracf(v.x), fracf(v.y));
+}
+inline __host__ __device__ float3 fracf(float3 v)
+{
+ return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
+}
+inline __host__ __device__ float4 fracf(float4 v)
+{
+ return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// fmod
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 fmodf(float2 a, float2 b)
+{
+ return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
+}
+inline __host__ __device__ float3 fmodf(float3 a, float3 b)
+{
+ return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
+}
+inline __host__ __device__ float4 fmodf(float4 a, float4 b)
+{
+ return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// absolute value
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float2 fabs(float2 v)
+{
+ return make_float2(fabs(v.x), fabs(v.y));
+}
+inline __host__ __device__ float3 fabs(float3 v)
+{
+ return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
+}
+inline __host__ __device__ float4 fabs(float4 v)
+{
+ return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
+}
+
+inline __host__ __device__ int2 abs(int2 v)
+{
+ return make_int2(abs(v.x), abs(v.y));
+}
+inline __host__ __device__ int3 abs(int3 v)
+{
+ return make_int3(abs(v.x), abs(v.y), abs(v.z));
+}
+inline __host__ __device__ int4 abs(int4 v)
+{
+ return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// reflect
+// - returns reflection of incident ray I around surface normal N
+// - N should be normalized; the reflected vector's length equals the length of I
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float3 reflect(float3 i, float3 n)
+{
+ return i - 2.0f * n * dot(n, i);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// cross product
+////////////////////////////////////////////////////////////////////////////////
+
+inline __host__ __device__ float3 cross(float3 a, float3 b)
+{
+ return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// smoothstep
+// - returns 0 if x < a
+// - returns 1 if x > b
+// - otherwise returns smooth interpolation between 0 and 1 based on x
+////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ __host__ float smoothstep(float a, float b, float x)
+{
+ float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
+ return (y*y*(3.0f - (2.0f*y)));
+}
+inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
+{
+ float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
+ return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
+}
+inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
+{
+ float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
+ return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
+}
+inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
+{
+ float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
+ return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
+}
+
+#endif
\ No newline at end of file
diff --git a/models/roartnet.py b/models/roartnet.py
new file mode 100644
index 0000000..d96c329
--- /dev/null
+++ b/models/roartnet.py
@@ -0,0 +1,97 @@
+from typing import List
+from itertools import combinations
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResLayer(nn.Module):
+ def __init__(self, dim_in:int, dim_out:int, bn:bool=False, ln:bool=False, dropout:float=0.):
+ super().__init__()
+ self.is_bn = bn
+ self.is_ln = ln
+ self.fc1 = nn.Linear(dim_in, dim_out)
+ if bn:
+ self.bn1 = nn.BatchNorm1d(dim_out)
+ else:
+ self.bn1 = lambda x: x
+ if ln:
+ self.ln1 = nn.LayerNorm(dim_out)
+ else:
+ self.ln1 = lambda x: x
+ self.fc2 = nn.Linear(dim_out, dim_out)
+ if bn:
+ self.bn2 = nn.BatchNorm1d(dim_out)
+ else:
+ self.bn2 = lambda x: x
+ if ln:
+ self.ln2 = nn.LayerNorm(dim_out)
+ else:
+ self.ln2 = lambda x: x
+ if dim_in != dim_out:
+ self.fc0 = nn.Linear(dim_in, dim_out)
+ else:
+ self.fc0 = None
+ self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
+
+ def forward(self, x):
+ x_res = x if self.fc0 is None else self.fc0(x)
+ x = self.fc1(x)
+ if len(x.shape) > 3 or len(x.shape) < 2:
+ raise ValueError("x.shape should be (B, N, D) or (N, D)")
+ elif len(x.shape) == 3 and self.is_bn:
+ x = x.permute(0, 2, 1) # from (B, N, D) to (B, D, N)
+ x = self.bn1(x)
+ x = x.permute(0, 2, 1) # from (B, D, N) to (B, N, D)
+ elif len(x.shape) == 2 and self.is_bn:
+ x = self.bn1(x)
+ elif self.is_ln:
+ x = self.ln1(x)
+ else:
+            x = self.bn1(x)  # here self.bn1 is the identity function (neither bn nor ln is enabled)
+ x = F.relu(x)
+
+ x = self.fc2(x)
+ if len(x.shape) > 3 or len(x.shape) < 2:
+ raise ValueError("x.shape should be (B, N, D) or (N, D)")
+ elif len(x.shape) == 3 and self.is_bn:
+ x = x.permute(0, 2, 1) # from (B, N, D) to (B, D, N)
+ x = self.bn2(x)
+ x = x.permute(0, 2, 1) # from (B, D, N) to (B, N, D)
+ elif len(x.shape) == 2 and self.is_bn:
+ x = self.bn2(x)
+ elif self.is_ln:
+ x = self.ln2(x)
+ else:
+            x = self.bn2(x)  # here self.bn2 is the identity function (neither bn nor ln is enabled)
+ x = self.dropout(x + x_res)
+ return x
+
+
+def create_MLP(input_dim:int, hidden_dims:List[int], output_dim:int,
+ bn:bool, ln:bool, dropout:float) -> nn.Module:
+    fcs = list(hidden_dims)  # copy so the caller's hidden_dims list is not mutated below
+ fcs.insert(0, input_dim)
+ fcs.append(output_dim)
+ MLP = nn.Sequential(
+ *[ResLayer(fcs[i], fcs[i+1], bn=bn, ln=ln, dropout=dropout)
+ for i in range(len(fcs) - 1)]
+ )
+ return MLP
+
+
+def create_shot_encoder(shot_hidden_dims:List[int], shot_feature_dim:int,
+ shot_bn:bool, shot_ln:bool, shot_dropout:float) -> nn.Module:
+ return create_MLP(input_dim=352, hidden_dims=shot_hidden_dims, output_dim=shot_feature_dim,
+ bn=shot_bn, ln=shot_ln, dropout=shot_dropout)
+
+def create_encoder(num_more:int, shot_feature_dim:int, has_rgb:bool,
+ overall_hidden_dims:List[int], rot_bin_num:int, joint_num:int,
+ overall_bn:bool, overall_ln:bool, overall_dropout:float) -> nn.Module:
+ # input order: (coords, normals, shots(, rgb))
+ overall_input_dim = len(list(combinations(np.arange(num_more + 2), 2))) * 4 + (num_more + 2) * shot_feature_dim + (3 * (num_more + 2) if has_rgb else 0)
+ # output order: (J*tr, J*rot, J*afford, J*conf)
+ overall_output_dim = (2 + rot_bin_num + 2 + 1) * joint_num
+
+ return create_MLP(input_dim=overall_input_dim, hidden_dims=overall_hidden_dims, output_dim=overall_output_dim,
+ bn=overall_bn, ln=overall_ln, dropout=overall_dropout)
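+
+
+# A minimal usage sketch (hypothetical, illustrative values only; none of these numbers come
+# from a config in this repo):
+#   shot_encoder = create_shot_encoder(shot_hidden_dims=[128, 128], shot_feature_dim=128,
+#                                      shot_bn=True, shot_ln=False, shot_dropout=0.)
+#   encoder = create_encoder(num_more=10, shot_feature_dim=128, has_rgb=False,
+#                            overall_hidden_dims=[256, 256], rot_bin_num=36, joint_num=1,
+#                            overall_bn=True, overall_ln=False, overall_dropout=0.)
+# Per sampled tuple, the encoder consumes the pairwise geometric features plus the per-point
+# SHOT embeddings and outputs (2 + rot_bin_num + 2 + 1) * joint_num values, as noted above.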
diff --git a/models/voting.py b/models/voting.py
new file mode 100644
index 0000000..3fed7cf
--- /dev/null
+++ b/models/voting.py
@@ -0,0 +1,321 @@
+"""
+Modified from https://github.com/qq456cvb/CPPF/blob/main/models/voting.py
+"""
+import os
+import cupy as cp
+
+
+helper_math_path = os.path.join(os.path.dirname(__file__), 'helper_math.cuh')
+
+
+ppf_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r'''
+ #define M_PI 3.14159265358979323846264338327950288
+ extern "C" __global__
+ void ppf_voting(
+ const float *points, const float *outputs, const float *probs, const int *point_idxs, float *grid_obj, const float *corner, const float res,
+ int n_ppfs, int n_rots, int grid_x, int grid_y, int grid_z
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < n_ppfs) {
+ float proj_len = outputs[idx * 2];
+ float odist = outputs[idx * 2 + 1];
+ if (odist < res) return;
+ int a_idx = point_idxs[idx * 2];
+ int b_idx = point_idxs[idx * 2 + 1];
+ float3 a = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]);
+ float3 b = make_float3(points[b_idx * 3], points[b_idx * 3 + 1], points[b_idx * 3 + 2]);
+ float3 ab = a - b;
+ if (length(ab) < 1e-7) return;
+ ab /= (length(ab) + 1e-7);
+ float3 c = a - ab * proj_len;
+
+ // float prob = max(probs[a_idx], probs[b_idx]);
+ float prob = probs[idx];
+ float3 co = make_float3(0.f, -ab.z, ab.y);
+ if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f);
+ float3 x = co / (length(co) + 1e-7) * odist;
+ float3 y = cross(x, ab);
+ int adaptive_n_rots = min(int(odist / res * (2 * M_PI)), n_rots);
+ // int adaptive_n_rots = n_rots;
+ for (int i = 0; i < adaptive_n_rots; i++) {
+ float angle = i * 2 * M_PI / adaptive_n_rots;
+ float3 offset = cos(angle) * x + sin(angle) * y;
+ float3 center_grid = (c + offset - make_float3(corner[0], corner[1], corner[2])) / res;
+ if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 ||
+ center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) {
+ continue;
+ }
+ int3 center_grid_floor = make_int3(center_grid);
+ int3 center_grid_ceil = center_grid_floor + 1;
+ float3 residual = fracf(center_grid);
+
+ float3 w0 = 1.f - residual;
+ float3 w1 = residual;
+
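+                // trilinear splat: distribute this vote over the 8 neighbouring grid cells,
+                // weighting each cell by the proximity of the candidate center to its corner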
+ float lll = w0.x * w0.y * w0.z;
+ float llh = w0.x * w0.y * w1.z;
+ float lhl = w0.x * w1.y * w0.z;
+ float lhh = w0.x * w1.y * w1.z;
+ float hll = w1.x * w0.y * w0.z;
+ float hlh = w1.x * w0.y * w1.z;
+ float hhl = w1.x * w1.y * w0.z;
+ float hhh = w1.x * w1.y * w1.z;
+
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], lll * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], llh * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], lhl * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], lhh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], hll * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], hlh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], hhl * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], hhh * prob);
+ }
+ }
+ }
+''', 'ppf_voting')
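+
+# Launch sketch for the kernel above (a hedged example, not code from this repo; the array
+# names, block size and grid shape below are assumptions):
+#   import math
+#   block = 512
+#   n_ppfs = point_idxs_cp.shape[0]
+#   grid_obj = cp.zeros((grid_x, grid_y, grid_z), dtype=cp.float32)
+#   ppf_kernel((math.ceil(n_ppfs / block),), (block,),
+#              (points_cp, outputs_cp, probs_cp, point_idxs_cp, grid_obj, corner_cp,
+#               cp.float32(res), n_ppfs, n_rots, grid_x, grid_y, grid_z))
+# where the *_cp inputs are contiguous cupy arrays (float32 / int32) matching the kernel signature.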
+
+ppf_retrieval_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r'''
+ #define M_PI 3.14159265358979323846264338327950288
+ extern "C" __global__
+ void ppf_voting_retrieval(
+ const float *point_pairs, const float *outputs, const float *probs, float *grid_obj, const float *corner, const float res,
+ int n_ppfs, int n_rots, int grid_x, int grid_y, int grid_z
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < n_ppfs) {
+ float proj_len = outputs[idx * 2];
+ float odist = outputs[idx * 2 + 1];
+ if (odist < res) return;
+ float3 a = make_float3(point_pairs[idx * 6], point_pairs[idx * 6 + 1], point_pairs[idx * 6 + 2]);
+ float3 b = make_float3(point_pairs[idx * 6 + 3], point_pairs[idx * 6 + 4], point_pairs[idx * 6 + 5]);
+ float3 ab = a - b;
+ if (length(ab) < 1e-7) return;
+ ab /= (length(ab) + 1e-7);
+ float3 c = a - ab * proj_len;
+
+ // float prob = max(probs[a_idx], probs[b_idx]);
+ float prob = probs[idx];
+ float3 co = make_float3(0.f, -ab.z, ab.y);
+ if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f);
+ float3 x = co / (length(co) + 1e-7) * odist;
+ float3 y = cross(x, ab);
+ int adaptive_n_rots = min(int(odist / res * (2 * M_PI)), n_rots);
+ // int adaptive_n_rots = n_rots;
+ for (int i = 0; i < adaptive_n_rots; i++) {
+ float angle = i * 2 * M_PI / adaptive_n_rots;
+ float3 offset = cos(angle) * x + sin(angle) * y;
+ float3 center_grid = (c + offset - make_float3(corner[0], corner[1], corner[2])) / res;
+ if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 ||
+ center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) {
+ continue;
+ }
+ int3 center_grid_floor = make_int3(center_grid);
+ int3 center_grid_ceil = center_grid_floor + 1;
+ float3 residual = fracf(center_grid);
+
+ float3 w0 = 1.f - residual;
+ float3 w1 = residual;
+
+ float lll = w0.x * w0.y * w0.z;
+ float llh = w0.x * w0.y * w1.z;
+ float lhl = w0.x * w1.y * w0.z;
+ float lhh = w0.x * w1.y * w1.z;
+ float hll = w1.x * w0.y * w0.z;
+ float hlh = w1.x * w0.y * w1.z;
+ float hhl = w1.x * w1.y * w0.z;
+ float hhh = w1.x * w1.y * w1.z;
+
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], lll * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], llh * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], lhl * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], lhh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], hll * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], hlh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], hhl * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], hhh * prob);
+ }
+ }
+ }
+''', 'ppf_voting_retrieval')
+
+ppf_direct_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r'''
+ #define M_PI 3.14159265358979323846264338327950288
+ extern "C" __global__
+ void ppf_voting_direct(
+ const float *points, const float *outputs, const float *probs, const int *point_idxs, float *grid_obj, const float *corner, const float res,
+ int n_ppfs, int grid_x, int grid_y, int grid_z
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < n_ppfs) {
+ int a_idx = point_idxs[idx];
+ float3 c = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]);
+ float prob = probs[idx];
+ float3 offset = make_float3(outputs[idx * 3], outputs[idx * 3 + 1], outputs[idx * 3 + 2]);
+ float3 center_grid = (c - offset - make_float3(corner[0], corner[1], corner[2])) / res;
+ if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 ||
+ center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) {
+ return;
+ }
+ int3 center_grid_floor = make_int3(center_grid);
+ int3 center_grid_ceil = center_grid_floor + 1;
+ float3 residual = fracf(center_grid);
+
+ float3 w0 = 1.f - residual;
+ float3 w1 = residual;
+
+ float lll = w0.x * w0.y * w0.z;
+ float llh = w0.x * w0.y * w1.z;
+ float lhl = w0.x * w1.y * w0.z;
+ float lhh = w0.x * w1.y * w1.z;
+ float hll = w1.x * w0.y * w0.z;
+ float hlh = w1.x * w0.y * w1.z;
+ float hhl = w1.x * w1.y * w0.z;
+ float hhh = w1.x * w1.y * w1.z;
+
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], lll * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], llh * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], lhl * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], lhh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], hll * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], hlh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], hhl * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], hhh * prob);
+ }
+ }
+''', 'ppf_voting_direct')
+
+
+rot_voting_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r'''
+ #define M_PI 3.14159265358979323846264338327950288
+ extern "C" __global__
+ void rot_voting(
+ const float *points, const float *preds_rot, float3 *outputs_up, const int *point_idxs,
+ int n_ppfs, int n_rots
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < n_ppfs) {
+ float rot = preds_rot[idx];
+ int a_idx = point_idxs[idx * 2];
+ int b_idx = point_idxs[idx * 2 + 1];
+ float3 a = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]);
+ float3 b = make_float3(points[b_idx * 3], points[b_idx * 3 + 1], points[b_idx * 3 + 2]);
+ float3 ab = a - b;
+ if (length(ab) < 1e-7) return;
+ ab /= length(ab);
+
+ float3 co = make_float3(0.f, -ab.z, ab.y);
+ if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f);
+ float3 x = co / (length(co) + 1e-7);
+ float3 y = cross(x, ab);
+
+ for (int i = 0; i < n_rots; i++) {
+ float angle = i * 2 * M_PI / n_rots;
+ float3 offset = cos(angle) * x + sin(angle) * y;
+ float3 up = tan(rot) * offset + (tan(rot) > 0 ? ab : -ab);
+ up = up / (length(up) + 1e-7);
+ outputs_up[idx * n_rots + i] = up;
+ }
+ }
+ }
+''', 'rot_voting')
+
+rot_voting_retrieval_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r'''
+ #define M_PI 3.14159265358979323846264338327950288
+ extern "C" __global__
+ void rot_voting_retrieval(
+ const float *point_pairs, const float *preds_rot, float3 *outputs_up,
+ int n_ppfs, int n_rots
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < n_ppfs) {
+ float rot = preds_rot[idx];
+ float3 a = make_float3(point_pairs[idx * 6], point_pairs[idx * 6 + 1], point_pairs[idx * 6 + 2]);
+ float3 b = make_float3(point_pairs[idx * 6 + 3], point_pairs[idx * 6 + 4], point_pairs[idx * 6 + 5]);
+ float3 ab = a - b;
+ if (length(ab) < 1e-7) return;
+ ab /= length(ab);
+
+ float3 co = make_float3(0.f, -ab.z, ab.y);
+ if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f);
+ float3 x = co / (length(co) + 1e-7);
+ float3 y = cross(x, ab);
+
+ for (int i = 0; i < n_rots; i++) {
+ float angle = i * 2 * M_PI / n_rots;
+ float3 offset = cos(angle) * x + sin(angle) * y;
+ float3 up = tan(rot) * offset + (tan(rot) > 0 ? ab : -ab);
+ up = up / (length(up) + 1e-7);
+ outputs_up[idx * n_rots + i] = up;
+ }
+ }
+ }
+''', 'rot_voting_retrieval')
+
+
+ppf4d_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r'''
+ #define M_PI 3.14159265358979323846264338327950288
+ extern "C" __global__
+ void ppf4d_voting(
+ const float *points, const float *outputs, const float *rot_outputs, const float *probs, const int *point_idxs, float *grid_obj, const float *corner, const float res,
+ int n_ppfs, int n_rots, int grid_x, int grid_y, int grid_z, int grid_w
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < n_ppfs) {
+ float proj_len = outputs[idx * 2];
+ float odist = outputs[idx * 2 + 1];
+ if (odist < res) return;
+ int a_idx = point_idxs[idx * 2];
+ int b_idx = point_idxs[idx * 2 + 1];
+ float3 a = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]);
+ float3 b = make_float3(points[b_idx * 3], points[b_idx * 3 + 1], points[b_idx * 3 + 2]);
+ float3 ab = a - b;
+ if (length(ab) < 1e-7) return;
+ ab /= (length(ab) + 1e-7);
+ float3 c = a - ab * proj_len;
+
+ // float prob = max(probs[a_idx], probs[b_idx]);
+ float prob = probs[idx];
+ float3 co = make_float3(0.f, -ab.z, ab.y);
+ if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f);
+ float3 x = co / (length(co) + 1e-7) * odist;
+ float3 y = cross(x, ab);
+ int adaptive_n_rots = min(int(odist / res * (2 * M_PI)), n_rots);
+ // int adaptive_n_rots = n_rots;
+ for (int i = 0; i < adaptive_n_rots; i++) {
+ float angle = i * 2 * M_PI / adaptive_n_rots;
+ float3 offset = cos(angle) * x + sin(angle) * y;
+ float3 center_grid = (c + offset - make_float3(corner[0], corner[1], corner[2])) / res;
+ if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 ||
+ center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) {
+ continue;
+ }
+ int3 center_grid_floor = make_int3(center_grid);
+ int3 center_grid_ceil = center_grid_floor + 1;
+ float3 residual = fracf(center_grid);
+
+ float3 w0 = 1.f - residual;
+ float3 w1 = residual;
+
+ float lll = w0.x * w0.y * w0.z;
+ float llh = w0.x * w0.y * w1.z;
+ float lhl = w0.x * w1.y * w0.z;
+ float lhh = w0.x * w1.y * w1.z;
+ float hll = w1.x * w0.y * w0.z;
+ float hlh = w1.x * w0.y * w1.z;
+ float hhl = w1.x * w1.y * w0.z;
+ float hhh = w1.x * w1.y * w1.z;
+
+ for (int j = 0; j < grid_w; j++) {
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * lll * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * llh * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * lhl * prob);
+ atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * lhh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * hll * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * hlh * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * hhl * prob);
+ atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * hhh * prob);
+ }
+ }
+ }
+ }
+''', 'ppf4d_voting')
diff --git a/outputs/.gitkeep b/outputs/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/real_eval.py b/real_eval.py
new file mode 100644
index 0000000..4c2efca
--- /dev/null
+++ b/real_eval.py
@@ -0,0 +1,257 @@
+import os
+import time
+import tqdm
+import numpy as np
+import transformations as tf
+
+from envs.real_camera import CameraL515
+from envs.real_robot import Panda
+from utilities.data_utils import transform_pc, transform_dir
+from utilities.vis_utils import visualize
+from utilities.network_utils import send, read
+
+
+if __name__ == '__main__':
+    print("please manually set the robot to a specific pose, and make sure all remote services are running")
+
+ cam2EE_path = "/home/franka/junbo/data/robot/l515_franka.npy"
+ cam2EE = np.load(cam2EE_path)
+ camera_loaded = False
+ robot_loaded = False
+ vis = True
+ temp_observation_path = "./temp_data/observation.npz"
+ temp_service_path = "./temp_data/service.npz"
+ temp_flag_path = "./temp_data/flag.npy"
+ remote_repo_path = "TODO" # TODO: set your own remote repo path
+ remote_observation_path = f"{remote_repo_path}/temp_data/observation.npz"
+ remote_service_path = f"{remote_repo_path}/temp_data/service.npz"
+ remote_flag_path = f"{remote_repo_path}/temp_data/flag.npy"
+ remote_ip = "TODO" # TODO: set your own remote ip
+ port = TODO # TODO: set your own remote port
+ username = "TODO" # TODO: set your own remote username
+ key_filename = "TODO" # TODO: set your own local key file path
+ task = -1 # close as 1, open as -1
+ time_steps = 10
+
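+    # file-based handshake with the remote inference server (see real_service.py):
+    #   1. write observation.npz locally and scp it to the remote repo
+    #   2. poll the remote flag.npy until the server marks the request as serviced
+    #   3. scp service.npz back, read the prediction, then reset flag.npy to False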
+ try:
+ print("===> initializing camera")
+ start_time = time.time()
+ camera = CameraL515()
+ camera_loaded = True
+ end_time = time.time()
+ print("===> camera initialized", end_time - start_time)
+
+ print("===> initializing robot")
+ start_time = time.time()
+ robot = Panda()
+ robot.gripper_open()
+ robot.homing()
+ robot_loaded = True
+ end_time = time.time()
+ print("===> robot initialized", end_time - start_time)
+
+ print("===> getting observation")
+ start_time = time.time()
+ color, depth = camera.get_data(hole_filling=False)
+ depth_sensor = camera.pipeline_profile.get_device().first_depth_sensor()
+ depth_scale = depth_sensor.get_depth_scale()
+ xyzrgb = camera.getXYZRGB(color, depth, np.identity(4), np.identity(4), camera.getIntrinsics(), inpaint=False, depth_scale=depth_scale)
+ # xyzrgb = xyzrgb[xyzrgb[:, 2] <= 1.5, :]
+ xyzrgb = xyzrgb[xyzrgb[:, 2] > 0.05, :]
+ cam_pc = xyzrgb[:, 0:3]
+ pc_color = xyzrgb[:, 3:6]
+ end_time = time.time()
+ print("===> observation got", end_time - start_time)
+
+ print("===> preprocessing observation")
+ start_time = time.time()
+ EE2robot = robot.readPose()
+ cam2robot = EE2robot @ cam2EE
+ robot2cam = np.linalg.inv(cam2robot)
+ base_pc = transform_pc(cam_pc, cam2robot)
+ space_mask_x = np.logical_and(base_pc[:, 0] > 0, base_pc[:, 0] < 1.1)
+ space_mask_y = np.logical_and(base_pc[:, 1] > -0.27, base_pc[:, 1] < 0.55)
+ # space_mask_z = base_pc[:, 2] > 0.02
+ # space_mask_z = base_pc[:, 2] > 0.55 # microwave: pad + safe (rotate)
+ # space_mask_z = base_pc[:, 2] > 0.52 # refrigerator: storagefurniture
+ # space_mask_z = base_pc[:, 2] > 0.4 # safe: pad + microwave
+ # space_mask_z = base_pc[:, 2] > 0.27 # storagefurniture: microwave
+ # space_mask_z = base_pc[:, 2] > 0.27 # drawer: microwave
+ space_mask_z = base_pc[:, 2] > 0.4 # washingmachine: pad + microwave
+ space_mask = np.logical_and(np.logical_and(space_mask_x, space_mask_y), space_mask_z)
+ base_pc_space = base_pc[space_mask, :]
+ pc_color_space = pc_color[space_mask, :]
+ cam_pc_space = transform_pc(base_pc_space, robot2cam)
+ end_time = time.time()
+ print("===> observation preprocessed", end_time - start_time)
+ np.savez("./observation.npz", point_cloud=cam_pc_space, rgb=pc_color_space)
+ if vis:
+ visualize(cam_pc_space, pc_color_space, whether_frame=True, whether_bbox=True, window_name="observation")
+
+ print("===> sending request")
+ start_time = time.time()
+ np.savez(temp_observation_path, point_cloud=cam_pc_space, rgb=pc_color_space)
+ time.sleep(0.5)
+ while not (os.path.isfile(temp_observation_path) and os.access(temp_observation_path, os.R_OK)):
+ time.sleep(0.1)
+ send(temp_observation_path, remote_observation_path,
+ remote_ip=remote_ip, port=port, username=username, key_filename=key_filename)
+ time.sleep(0.5)
+ os.remove(temp_observation_path)
+ end_time = time.time()
+ print("===> request sent", end_time - start_time)
+
+ print("===> reading response")
+ start_time = time.time()
+ while True:
+ read(temp_flag_path, remote_flag_path,
+ remote_ip=remote_ip, port=port, username=username, key_filename=key_filename)
+ time.sleep(0.5)
+ got_service = np.load(temp_flag_path).item()
+ if got_service:
+ os.remove(temp_flag_path)
+ break
+ else:
+ time.sleep(0.5)
+ read(temp_service_path, remote_service_path,
+ remote_ip=remote_ip, port=port, username=username, key_filename=key_filename)
+ time.sleep(0.5)
+ service = np.load(temp_service_path, allow_pickle=True)
+ num_grasps = service['num_grasps']
+ if num_grasps == 0:
+ print("no grasps detected")
+ else:
+ cam_joint_base = service['joint_base']
+ cam_joint_direction = service['joint_direction']
+ cam_affordable_position = service['affordable_position']
+ joint_type = service['joint_type']
+ joint_re = service['joint_re']
+ grasp_score = service['grasp_score']
+ grasp_width = service['grasp_width']
+ grasp_depth = service['grasp_depth']
+ grasp_affordance = service['grasp_affordance']
+ cam_grasp_translation = service['grasp_translation']
+ cam_grasp_rotation = service['grasp_rotation']
+ cam_grasp_pose = np.eye(4)
+ cam_grasp_pose[:3, 3] = cam_grasp_translation
+ cam_grasp_pose[:3, :3] = cam_grasp_rotation
+ base_joint_base = transform_pc(cam_joint_base[None, :], cam2robot)[0]
+ base_joint_direction = transform_dir(cam_joint_direction[None, :], cam2robot)[0]
+ base_affordable_position = transform_pc(cam_affordable_position[None, :], cam2robot)[0]
+ base_grasp_pose = cam2robot @ cam_grasp_pose
+ base_grasp_pose[:3, 3] += (grasp_depth - 0.05) * base_grasp_pose[:3, 0] # TODO: hardcode to avoid collision
+ if joint_type == 0:
+ # TODO: only for horizontal grasp to avoid singular robot state
+ flip = np.arccos(np.dot(base_grasp_pose[:3, 2], np.array([0., 0., 1.]))) / np.pi * 180.0 < 45
+ if flip:
+ print("flipped")
+ base_grasp_pose[:3, 1] = -base_grasp_pose[:3, 1]
+ base_grasp_pose[:3, 2] = -base_grasp_pose[:3, 2]
+ rotate = base_grasp_pose[:3, 0][2] > 0
+ if rotate:
+ print("rotated")
+ target_x_axis = base_grasp_pose[:3, 0].copy()
+ target_x_axis[2] = -target_x_axis[2]
+ rotation_angle = np.arccos(np.dot(base_grasp_pose[:3, 0], target_x_axis))
+ rotation_direction = np.array([base_grasp_pose[:3, 0][0], base_grasp_pose[:3, 0][1]])
+ rotation_direction /= np.linalg.norm(rotation_direction)
+ rotation_direction = np.array([-rotation_direction[1], rotation_direction[0], 0.])
+ rotation_matrix = tf.rotation_matrix(angle=rotation_angle, direction=rotation_direction, point=base_grasp_pose[:3, 3])
+ base_grasp_pose = rotation_matrix @ base_grasp_pose
+ elif joint_type == 1:
+ horizontal = np.arccos(np.dot(base_grasp_pose[:3, 0], np.array([1., 0., 0.]))) / np.pi * 180.0 < 45
+ if horizontal:
+ print("horizontal")
+ else:
+ print("vertical")
+ else:
+ raise ValueError
+ base_pre_grasp_pose = base_grasp_pose.copy()
+ base_pre_grasp_pose[:3, 3] -= 0.05 * base_pre_grasp_pose[:3, 0]
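+            # g2g is a fixed rotation mapping the predicted grasp frame to the robot end-effector
+            # frame; only the orientation changes, the grasp position is kept as the EE position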
+ g2g = np.array([[0., 0., -1.], [0., -1., 0.], [-1., 0., 0.]])
+ base_gripper_pose = np.eye(4)
+ base_gripper_pose[:3, :3] = base_grasp_pose[:3, :3] @ g2g
+ base_gripper_pose[:3, 3] = base_grasp_pose[:3, 3]
+ base_pre_gripper_pose = np.eye(4)
+ base_pre_gripper_pose[:3, :3] = base_pre_grasp_pose[:3, :3] @ g2g
+ base_pre_gripper_pose[:3, 3] = base_pre_grasp_pose[:3, 3]
+ np.savez("./joint.npz", joint_base=base_joint_base, joint_direction=base_joint_direction, affordable_position=base_affordable_position)
+ np.savez("./grasp.npz", grasp_pose=base_grasp_pose, grasp_width=grasp_width)
+ time.sleep(0.5)
+ os.remove(temp_service_path)
+ end_time = time.time()
+ print("===> response read", end_time - start_time)
+ if vis:
+ if num_grasps != 0:
+ visualize(base_pc_space, pc_color_space,
+ joint_translations=base_joint_base[None, :], joint_rotations=base_joint_direction[None, :], affordable_positions=base_affordable_position[None, :],
+ grasp_poses=base_grasp_pose[None, ...], grasp_widths=np.array([grasp_width]), grasp_depths=np.array([0.]), grasp_affordances=np.array([grasp_affordance]),
+ whether_frame=True, whether_bbox=True, window_name="prediction")
+
+ print("===> resetting flag")
+ serviced = np.array(False)
+ np.save(temp_flag_path, serviced)
+ time.sleep(0.5)
+ while not (os.path.isfile(temp_flag_path) and os.access(temp_flag_path, os.R_OK)):
+ time.sleep(0.1)
+ send(temp_flag_path, remote_flag_path,
+ remote_ip=remote_ip, port=port, username=username, key_filename=key_filename)
+ time.sleep(0.5)
+ os.remove(temp_flag_path)
+ end_time = time.time()
+ print("===> flag reset", end_time - start_time)
+
+ print("===> starting manipulation")
+        import pdb; pdb.set_trace()  # deliberate pause: let the operator confirm the prediction before the robot moves
+ start_time = time.time()
+ if num_grasps == 0:
+ exit(1)
+ else:
+ robot.move_gripper(grasp_width)
+
+ robot.movePose(base_pre_gripper_pose)
+
+ robot.movePose(base_gripper_pose)
+
+ # grasp
+ is_graspped = robot.gripper_close()
+ is_graspped = is_graspped and robot.is_grasping()
+ print(is_graspped)
+
+ # move
+ real_trajectory = []
+ target_trajectory = []
+ wrench_trajectory = []
+ robot.start_impedance_control()
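+            # each step commands a small delta along the predicted joint:
+            #   revolute (joint_type 0): rotate 5 degrees about the predicted axis through joint_base
+            #   prismatic (joint_type 1): translate 5 cm along the predicted joint direction
+            # the sign comes from the task (open/close) and, for revolute joints, from joint_re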
+ for time_step in tqdm.trange(time_steps):
+ current_EE2robot = robot.readPose()
+                current_wrench = robot.readWrench()
+ if joint_type == 0:
+ rotation_angle = -5.0 * task * joint_re / 180.0 * np.pi
+ delta_pose = tf.rotation_matrix(angle=rotation_angle, direction=base_joint_direction, point=base_joint_base)
+ elif joint_type == 1:
+ translation_distance = -5.0 * task / 100.0
+ delta_pose = tf.translation_matrix(base_joint_direction * translation_distance)
+ else:
+ raise ValueError
+ target_EE2robot = delta_pose @ current_EE2robot
+ robot.movePose(target_EE2robot)
+ time.sleep(0.3)
+ real_trajectory.append(current_EE2robot)
+ target_trajectory.append(target_EE2robot)
+ wrench_trajectory.append(current_wrench)
+ robot.end_impedance_control()
+ real_trajectory = np.array(real_trajectory)
+ target_trajectory = np.array(target_trajectory)
+ wrench_trajectory = np.array(wrench_trajectory)
+ np.savez("./trajectory.npz", real_trajectory=real_trajectory, target_trajectory=target_trajectory, wrench_trajectory=wrench_trajectory)
+ end_time = time.time()
+ print("===> manipulation done", end_time - start_time)
+ robot.gripper_open()
+ # robot.homing()
+ except Exception as e:
+ print(e)
+ if camera_loaded:
+ del camera
+ if robot_loaded:
+ del robot
diff --git a/real_service.py b/real_service.py
new file mode 100644
index 0000000..4ba987a
--- /dev/null
+++ b/real_service.py
@@ -0,0 +1,278 @@
+import os
+import configargparse
+from omegaconf import OmegaConf
+import time
+import numpy as np
+import torch
+
+from utilities.env_utils import setup_seed
+from utilities.data_utils import transform_pc, transform_dir
+from utilities.metrics_utils import invaffordance_metrics, invaffordances2affordance
+from utilities.constants import seed, max_grasp_width
+
+
+def config_parse() -> configargparse.Namespace:
+ parser = configargparse.ArgumentParser()
+
+ # data config
+ parser.add_argument('--cat', type=str, default='Microwave', help='the category of the object')
+ # model config
+ parser.add_argument('--roartnet', action='store_true', help='whether call roartnet')
+ parser.add_argument('--roartnet_config_path', type=str, default='./configs/eval_config.yaml', help='the path to roartnet config')
+ # grasp config
+ parser.add_argument('--graspnet', action='store_true', help='whether call graspnet')
+ parser.add_argument('--gsnet_weight_path', type=str, default='./weights/checkpoint_detection.tar', help='the path to graspnet weight')
+ parser.add_argument('--max_grasp_width', type=float, default=max_grasp_width, help='the max width of the gripper')
+ # task config
+ parser.add_argument('--selected_part', type=int, default=0, help='the selected part of the object')
+ # others
+ parser.add_argument('--seed', type=int, default=seed, help='the random seed')
+
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+    print("please clear the temporary data directory")
+
+ args = config_parse()
+ setup_seed(args.seed)
+ temp_request_path = './temp_data/observation.npz'
+ temp_response_path = './temp_data/service.npz'
+ temp_flag_path = './temp_data/flag.npy'
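+    # per-category joint annotations used when answering a request:
+    #   joint_types: one entry per movable part, 0 = revolute, 1 = prismatic
+    #   joint_res:   rotation sense (+1/-1) for revolute joints, 0 for prismatic parts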
+ if args.cat == "Microwave":
+ joint_types = [0]
+ joint_res = [-1]
+ elif args.cat == "Refrigerator":
+ joint_types = [0]
+ joint_res = [1]
+ elif args.cat == "Safe":
+ joint_types = [0]
+ joint_res = [1]
+ elif args.cat == "StorageFurniture":
+ joint_types = [1, 0]
+ joint_res = [0, -1]
+ elif args.cat == "Drawer":
+ joint_types = [1, 1, 1]
+ joint_res = [0, 0, 0]
+ elif args.cat == "WashingMachine":
+ joint_types = [0]
+ joint_res = [-1]
+ else:
+ raise ValueError(f"Unknown category {args.cat}")
+
+ if args.roartnet:
+ print("===> loading roartnet")
+ start_time = time.time()
+ from models.roartnet import create_shot_encoder, create_encoder
+ from inference import inference_fn as roartnet_inference_fn
+ roartnet_cfg = OmegaConf.load(args.roartnet_config_path)
+ trained_path = roartnet_cfg.trained.path[args.cat]
+ trained_cfg = OmegaConf.load(f"{trained_path}/.hydra/config.yaml")
+ roartnet_cfg = OmegaConf.merge(trained_cfg, roartnet_cfg)
+ joint_num = roartnet_cfg.dataset.joint_num
+ resolution = roartnet_cfg.dataset.resolution
+ receptive_field = roartnet_cfg.dataset.receptive_field
+ has_rgb = roartnet_cfg.dataset.rgb
+ denoise = roartnet_cfg.dataset.denoise
+ normalize = roartnet_cfg.dataset.normalize
+ sample_points_num = roartnet_cfg.dataset.sample_points_num
+ sample_tuples_num = roartnet_cfg.algorithm.sampling.sample_tuples_num
+ tuple_more_num = roartnet_cfg.algorithm.sampling.tuple_more_num
+ shot_hidden_dims = roartnet_cfg.algorithm.shot_encoder.hidden_dims
+ shot_feature_dim = roartnet_cfg.algorithm.shot_encoder.feature_dim
+ shot_bn = roartnet_cfg.algorithm.shot_encoder.bn
+ shot_ln = roartnet_cfg.algorithm.shot_encoder.ln
+ shot_dropout = roartnet_cfg.algorithm.shot_encoder.dropout
+ shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim,
+ shot_bn, shot_ln, shot_dropout)
+ shot_encoder.load_state_dict(torch.load(f'{trained_path}/weights/shot_encoder_latest.pth', map_location=torch.device('cuda')))
+ shot_encoder = shot_encoder.cuda()
+ shot_encoder.eval()
+ overall_hidden_dims = roartnet_cfg.algorithm.encoder.hidden_dims
+ rot_bin_num = roartnet_cfg.algorithm.voting.rot_bin_num
+ overall_bn = roartnet_cfg.algorithm.encoder.bn
+ overall_ln = roartnet_cfg.algorithm.encoder.ln
+ overall_dropout = roartnet_cfg.algorithm.encoder.dropout
+ encoder = create_encoder(tuple_more_num, shot_feature_dim, has_rgb, overall_hidden_dims, rot_bin_num, joint_num,
+ overall_bn, overall_ln, overall_dropout)
+ encoder.load_state_dict(torch.load(f'{trained_path}/weights/encoder_latest.pth', map_location=torch.device('cuda')))
+ encoder = encoder.cuda()
+ encoder.eval()
+ voting_num = roartnet_cfg.algorithm.voting.voting_num
+ angle_tol = roartnet_cfg.algorithm.voting.angle_tol
+ translation2pc = roartnet_cfg.algorithm.voting.translation2pc
+ multi_candidate = roartnet_cfg.algorithm.voting.multi_candidate
+ candidate_threshold = roartnet_cfg.algorithm.voting.candidate_threshold
+ rotation_multi_neighbor = roartnet_cfg.algorithm.voting.rotation_multi_neighbor
+ neighbor_threshold = roartnet_cfg.algorithm.voting.neighbor_threshold
+ rotation_cluster = roartnet_cfg.algorithm.voting.rotation_cluster
+ bmm_size = roartnet_cfg.algorithm.voting.bmm_size
+ end_time = time.time()
+ print(f"===> loaded roartnet {end_time - start_time}")
+
+ if args.graspnet:
+ print("===> loading graspnet")
+ start_time = time.time()
+ from munch import DefaultMunch
+ from gsnet import AnyGrasp
+ grasp_detector_cfg = {
+ 'checkpoint_path': args.gsnet_weight_path,
+ 'max_gripper_width': args.max_grasp_width,
+ 'gripper_height': 0.03,
+ 'top_down_grasp': False,
+ 'add_vdistance': True,
+ 'debug': True
+ }
+ grasp_detector_cfg = DefaultMunch.fromDict(grasp_detector_cfg)
+ grasp_detector = AnyGrasp(grasp_detector_cfg)
+ grasp_detector.load_net()
+ end_time = time.time()
+ print(f"===> loaded graspnet {end_time - start_time}")
+
+ while True:
+ print("===> listening to request")
+ start_time = time.time()
+ serviced = np.array(False)
+ np.save(temp_flag_path, serviced)
+ got_request = False
+ while not got_request:
+ got_request = os.path.exists(temp_request_path)
+            time.sleep(5.0)  # NOTE: hardcoded to be longer than the time needed to write the request file
+ if got_request:
+ while not (os.path.isfile(temp_request_path) and os.access(temp_request_path, os.R_OK)):
+ time.sleep(0.1)
+ else:
+ time.sleep(0.1)
+ observation = np.load(temp_request_path, allow_pickle=True)
+ cam_pc = observation['point_cloud']
+ pc_rgb = observation['rgb']
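+        # fixed rotation from the camera frame of the observation to the frame the model expects
+        # (assumed axis convention); predictions are rotated back with its inverse before replying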
+ c2c = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
+ cam_pc_model = transform_pc(cam_pc, c2c)
+ time.sleep(0.5)
+ os.remove(temp_request_path)
+ end_time = time.time()
+ print(f"===> got request {end_time - start_time}")
+
+ print("===> inferencing model")
+ if args.roartnet:
+ start_time = time.time()
+ pred_joint_bases, pred_joint_directions, pred_affordable_positions = roartnet_inference_fn(cam_pc_model, pc_rgb if has_rgb else None, shot_encoder, encoder,
+ denoise, normalize, resolution, receptive_field, sample_points_num, sample_tuples_num, tuple_more_num,
+ voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold, bmm_size, joint_num, device=0)
+ pred_selected_joint_base = pred_joint_bases[args.selected_part]
+ pred_selected_joint_direction = pred_joint_directions[args.selected_part]
+ pred_selected_affordable_position = pred_affordable_positions[args.selected_part]
+ pred_selected_joint_base = transform_pc(pred_selected_joint_base[None, :], np.linalg.inv(c2c))[0]
+ pred_selected_joint_direction = transform_dir(pred_selected_joint_direction[None, :], np.linalg.inv(c2c))[0]
+ pred_selected_affordable_position = transform_pc(pred_selected_affordable_position[None, :], np.linalg.inv(c2c))[0]
+ end_time = time.time()
+            print(f"===> roartnet predicted {end_time - start_time}")
+
+ print("===> detecting grasps")
+ if args.graspnet:
+ start_time = time.time()
+ # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=None, lims=[
+ # np.floor(np.min(pcd_grasp[:, 0])) - 0.1, np.ceil(np.max(pcd_grasp[:, 0])) + 0.1,
+ # np.floor(np.min(pcd_grasp[:, 1])) - 0.1, np.ceil(np.max(pcd_grasp[:, 1])) + 0.1,
+ # np.floor(np.min(pcd_grasp[:, 2])) - 0.1, np.ceil(np.max(pcd_grasp[:, 2])) + 0.1])
+ # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=[-float('inf'), float('inf'), -float('inf'), float('inf'), -float('inf'), float('inf')])
+ # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, apply_object_mask=True, dense_grasp=False, collision_detection=True)
+ try:
+ gg_grasp = grasp_detector.get_grasp(cam_pc.astype(np.float32), colors=pc_rgb, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='fast')
+            except Exception:  # fall back to collision_detection='slow' if the 'fast' mode fails
+ gg_grasp = grasp_detector.get_grasp(cam_pc.astype(np.float32), colors=pc_rgb, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='slow')
+ if gg_grasp is None:
+ gg_grasp = []
+ else:
+ if len(gg_grasp) != 2:
+ gg_grasp = []
+ else:
+ gg_grasp, pcd_o3d = gg_grasp
+ gg_grasp = gg_grasp.nms().sort_by_score()
+ # grippers_o3d = gg_grasp.to_open3d_geometry_list()
+ # frame_o3d = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
+ # o3d.visualization.draw_geometries([*grippers_o3d, pcd_o3d, frame_o3d])
+ if len(gg_grasp) == 0:
+ np.savez(temp_response_path,
+ joint_base=pred_selected_joint_base,
+ joint_direction=pred_selected_joint_direction,
+ affordable_position=pred_selected_affordable_position,
+ joint_type=joint_types[args.selected_part],
+ joint_re=joint_res[args.selected_part],
+ num_grasps=0)
+ while not (os.path.isfile(temp_response_path) and os.access(temp_response_path, os.R_OK)):
+ time.sleep(0.1)
+ serviced = np.array(True)
+ np.save(temp_flag_path, serviced)
+ while True:
+ while not (os.path.isfile(temp_flag_path) and os.access(temp_flag_path, os.R_OK)):
+ time.sleep(0.1)
+ serviced = np.load(temp_flag_path).item()
+ if not serviced:
+ os.remove(temp_flag_path)
+ os.remove(temp_response_path)
+ break
+ else:
+ time.sleep(0.1)
+ continue
+ grasp_scores, grasp_widths, grasp_depths, grasp_translations, grasp_rotations, grasp_invaffordances = [], [], [], [], [], []
+ for g_idx, g_grasp in enumerate(gg_grasp):
+ grasp_score = g_grasp.score
+ grasp_scores.append(grasp_score)
+ grasp_width = g_grasp.width
+ grasp_widths.append(grasp_width)
+ grasp_depth = g_grasp.depth
+ grasp_depths.append(grasp_depth)
+ grasp_translation = g_grasp.translation
+ grasp_translations.append(grasp_translation)
+ grasp_rotation = g_grasp.rotation_matrix
+ grasp_rotations.append(grasp_rotation)
+ grasp_invaffordance = invaffordance_metrics(grasp_translation, grasp_rotation, grasp_score, pred_selected_affordable_position,
+ pred_selected_joint_base, pred_selected_joint_direction, joint_types[args.selected_part])
+ grasp_invaffordances.append(grasp_invaffordance)
+ grasp_affordances = invaffordances2affordance(grasp_invaffordances)
+ selected_grasp_idx = np.argmax(grasp_affordances)
+ selected_grasp_score = grasp_scores[selected_grasp_idx]
+ selected_grasp_width = grasp_widths[selected_grasp_idx]
+ selected_grasp_width = max(min(selected_grasp_width * 1.5, args.max_grasp_width), 0.0)
+ selected_grasp_depth = grasp_depths[selected_grasp_idx]
+ selected_grasp_translation = grasp_translations[selected_grasp_idx]
+ selected_grasp_rotation = grasp_rotations[selected_grasp_idx]
+ selected_grasp_affordance = grasp_affordances[selected_grasp_idx]
+ end_time = time.time()
+ print(f"===> anygrasp detected {end_time - start_time} {len(gg_grasp)}")
+
+ print("===> sending response")
+ start_time = time.time()
+ np.savez(temp_response_path,
+ joint_base=pred_selected_joint_base,
+ joint_direction=pred_selected_joint_direction,
+ affordable_position=pred_selected_affordable_position,
+ joint_type=joint_types[args.selected_part],
+ joint_re=joint_res[args.selected_part],
+ num_grasps=len(gg_grasp),
+ grasp_score=selected_grasp_score,
+ grasp_width=selected_grasp_width,
+ grasp_depth=selected_grasp_depth,
+ grasp_translation=selected_grasp_translation,
+ grasp_rotation=selected_grasp_rotation,
+ grasp_affordance=selected_grasp_affordance)
+ while not (os.path.isfile(temp_response_path) and os.access(temp_response_path, os.R_OK)):
+ time.sleep(0.1)
+ serviced = np.array(True)
+ np.save(temp_flag_path, serviced)
+ while True:
+ while not (os.path.isfile(temp_flag_path) and os.access(temp_flag_path, os.R_OK)):
+ time.sleep(0.1)
+ serviced = np.load(temp_flag_path).item()
+ if not serviced:
+ os.remove(temp_flag_path)
+ os.remove(temp_response_path)
+ break
+ else:
+ time.sleep(0.1)
+ end_time = time.time()
+ print(f"===> sent response {end_time - start_time}")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..14e8b37
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+requests
+ConfigArgParse
+trimesh
+tqdm
+Pillow
+scipy
+open3d==0.18.0
+opencv-python
+omegaconf
+hydra-core
+scikit-learn
+tensorboard
+transformations
+paramiko
diff --git a/scripts/eval_gt.sh b/scripts/eval_gt.sh
new file mode 100644
index 0000000..7d63041
--- /dev/null
+++ b/scripts/eval_gt.sh
@@ -0,0 +1,11 @@
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task pull --abbr 7119 --seed 42
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task push --abbr 7119 --seed 43
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task pull --abbr 7263 --seed 44
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task push --abbr 7263 --seed 45
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task pull --abbr 7296 --seed 46
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task push --abbr 7296 --seed 47
diff --git a/scripts/eval_roartnet.sh b/scripts/eval_roartnet.sh
new file mode 100644
index 0000000..26fb479
--- /dev/null
+++ b/scripts/eval_roartnet.sh
@@ -0,0 +1,11 @@
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task pull --abbr 7119 --seed 42
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task push --abbr 7119 --seed 43
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task pull --abbr 7263 --seed 44
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task push --abbr 7263 --seed 45
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task pull --abbr 7296 --seed 46
+
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task push --abbr 7296 --seed 47
diff --git a/scripts/real_service.sh b/scripts/real_service.sh
new file mode 100644
index 0000000..f684545
--- /dev/null
+++ b/scripts/real_service.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python real_service.py --cat Microwave --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --selected_part 0 --seed 42
diff --git a/scripts/test.sh b/scripts/test.sh
new file mode 100644
index 0000000..a4a095a
--- /dev/null
+++ b/scripts/test.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python test.py
diff --git a/scripts/test_config.sh b/scripts/test_config.sh
new file mode 100644
index 0000000..7d0304e
--- /dev/null
+++ b/scripts/test_config.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python test.py --cfg job
diff --git a/scripts/test_gt.sh b/scripts/test_gt.sh
new file mode 100644
index 0000000..7e734e0
--- /dev/null
+++ b/scripts/test_gt.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python test_gt.py
diff --git a/scripts/test_gt_config.sh b/scripts/test_gt_config.sh
new file mode 100644
index 0000000..6edeae5
--- /dev/null
+++ b/scripts/test_gt_config.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python test_gt.py --cfg job
diff --git a/scripts/test_real.sh b/scripts/test_real.sh
new file mode 100644
index 0000000..21dda9b
--- /dev/null
+++ b/scripts/test_real.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python test_real.py
diff --git a/scripts/test_real_config.sh b/scripts/test_real_config.sh
new file mode 100644
index 0000000..8248ea1
--- /dev/null
+++ b/scripts/test_real_config.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python test_real.py --cfg job
diff --git a/scripts/train.sh b/scripts/train.sh
new file mode 100644
index 0000000..5c46993
--- /dev/null
+++ b/scripts/train.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python train.py
diff --git a/scripts/train_config.sh b/scripts/train_config.sh
new file mode 100644
index 0000000..58a18ce
--- /dev/null
+++ b/scripts/train_config.sh
@@ -0,0 +1 @@
+export OMP_NUM_THREADS=16; python train.py --cfg job
diff --git a/src_shot/CMakeLists.txt b/src_shot/CMakeLists.txt
new file mode 100644
index 0000000..79ba892
--- /dev/null
+++ b/src_shot/CMakeLists.txt
@@ -0,0 +1,17 @@
+cmake_minimum_required(VERSION 2.8.12)
+project(shot)
+set (CMAKE_CXX_STANDARD 11)
+
+find_package( PythonInterp 3.6 REQUIRED )
+find_package( PythonLibs 3.6 REQUIRED )
+find_package( pybind11 REQUIRED )
+find_package( PCL 1.8 REQUIRED )
+
+# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
+
+include_directories( ${PCL_INCLUDE_DIRS} )
+# link_directories( ${PCL_LIBRARY_DIRS} )
+add_definitions(${PCL_DEFINITIONS})
+
+pybind11_add_module(shot shot.cpp)
+target_link_libraries(shot PUBLIC ${PCL_LIBRARIES})
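+
+# A typical out-of-source build of this module might look like the following
+# (a sketch, assuming pybind11 and PCL can be found by CMake):
+#   mkdir -p build && cd build
+#   cmake ..
+#   make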
diff --git a/src_shot/shot.cpp b/src_shot/shot.cpp
new file mode 100644
index 0000000..02af856
--- /dev/null
+++ b/src_shot/shot.cpp
@@ -0,0 +1,165 @@
+/*
+Modified from https://github.com/qq456cvb/CPPF2/blob/main/src_shot/shot.cpp
+*/
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pcl/point_types.h>
+#include <pcl/features/normal_3d.h>
+#include <pcl/features/shot.h>
+#include <pcl/search/kdtree.h>
+#include <pcl/console/print.h>
+
+namespace py = pybind11;
+using namespace pybind11::literals;
+
+py::array_t<float> estimate_normal(py::array_t<float> pc, double normal_r)
+{
+    pcl::PointCloud<pcl::PointXYZ>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZ>);
+ cloud->points.resize(pc.shape(0));
+ float *pc_ptr = (float*)pc.request().ptr;
+ for (int i = 0; i < pc.shape(0); ++i)
+ {
+ std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]);
+ // std::cout << cloud->points[i] << std::endl;
+ pc_ptr += 3;
+ }
+    pcl::PointCloud<pcl::Normal>::Ptr normals(new pcl::PointCloud<pcl::Normal>);
+
+    pcl::NormalEstimation<pcl::PointXYZ, pcl::Normal> normalEstimation;
+ normalEstimation.setInputCloud(cloud);
+ normalEstimation.setRadiusSearch(normal_r);
+ // normalEstimation.setKSearch(40);
+    pcl::search::KdTree<pcl::PointXYZ>::Ptr kdtree(new pcl::search::KdTree<pcl::PointXYZ>);
+ normalEstimation.setSearchMethod(kdtree);
+ normalEstimation.compute(*normals);
+
+    auto result = py::array_t<float>(normals->points.size() * 3);
+ auto buf = result.request();
+ float *ptr = (float*)buf.ptr;
+ for (int i = 0; i < normals->points.size(); ++i)
+ {
+ std::copy(&normals->points[i].normal[0], &normals->points[i].normal[3], &ptr[i * 3]);
+ }
+ return result;
+}
+
+
+py::array_t<float> compute(py::array_t<float> pc, double normal_r, double shot_r)
+{
+ // Object for storing the point cloud.
+    pcl::PointCloud<pcl::PointXYZ>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZ>);
+ cloud->points.resize(pc.shape(0));
+ float *pc_ptr = (float*)pc.request().ptr;
+ for (int i = 0; i < pc.shape(0); ++i)
+ {
+ std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]);
+ // std::cout << cloud->points[i] << std::endl;
+ pc_ptr += 3;
+ }
+
+ // Object for storing the normals.
+    pcl::PointCloud<pcl::Normal>::Ptr normals(new pcl::PointCloud<pcl::Normal>);
+ // Object for storing the SHOT descriptors for each point.
+    pcl::PointCloud<pcl::SHOT352>::Ptr descriptors(new pcl::PointCloud<pcl::SHOT352>());
+
+ // Note: you would usually perform downsampling now. It has been omitted here
+ // for simplicity, but be aware that computation can take a long time.
+
+ // Estimate the normals.
+    pcl::NormalEstimation<pcl::PointXYZ, pcl::Normal> normalEstimation;
+ normalEstimation.setInputCloud(cloud);
+ normalEstimation.setRadiusSearch(normal_r);
+ // normalEstimation.setKSearch(40);
+    pcl::search::KdTree<pcl::PointXYZ>::Ptr kdtree(new pcl::search::KdTree<pcl::PointXYZ>);
+ normalEstimation.setSearchMethod(kdtree);
+ normalEstimation.compute(*normals);
+
+ // SHOT estimation object.
+    pcl::SHOTEstimation<pcl::PointXYZ, pcl::Normal, pcl::SHOT352> shot;
+ shot.setInputCloud(cloud);
+ shot.setInputNormals(normals);
+ // The radius that defines which of the keypoint's neighbors are described.
+ // If too large, there may be clutter, and if too small, not enough points may be found.
+ shot.setRadiusSearch(shot_r);
+// shot.setKSearch(40);
+ shot.compute(*descriptors);
+
+    auto result = py::array_t<float>(descriptors->points.size() * 352);
+ auto buf = result.request();
+ float *ptr = (float*)buf.ptr;
+
+ for (int i = 0; i < descriptors->points.size(); ++i)
+ {
+ std::copy(&descriptors->points[i].descriptor[0], &descriptors->points[i].descriptor[352], &ptr[i * 352]);
+ }
+ return result;
+}
+
+py::array_t<float> compute_color(py::array_t<float> pc, py::array_t<float> pc_color, double normal_r, double shot_r)
+{
+ // Object for storing the point cloud.
+    pcl::PointCloud<pcl::PointXYZRGB>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZRGB>);
+ cloud->points.resize(pc.shape(0));
+ float *pc_ptr = (float*)pc.request().ptr;
+ float *color_ptr = (float*)pc_color.request().ptr;
+ for (int i = 0; i < pc.shape(0); ++i)
+ {
+ cloud->points[i].x = *pc_ptr;
+ cloud->points[i].y = *(pc_ptr + 1);
+ cloud->points[i].z = *(pc_ptr + 2);
+
+ uint8_t r = (*color_ptr) * 255.f;
+ uint8_t g = (*(color_ptr + 1)) * 255.f;
+ uint8_t b = (*(color_ptr + 2)) * 255.f;
+ uint32_t rgb = ((std::uint32_t)r << 16 | (std::uint32_t)g << 8 | (std::uint32_t)b);
+        cloud->points[i].rgb = *reinterpret_cast<float*>(&rgb);
+ // std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]);
+ // std::copy(color_ptr, color_ptr + 3, &cloud->points[i].data[3]);
+ pc_ptr += 3;
+ color_ptr += 3;
+ }
+
+ // Object for storing the normals.
+    pcl::PointCloud<pcl::Normal>::Ptr normals(new pcl::PointCloud<pcl::Normal>);
+ // Object for storing the SHOT descriptors for each point.
+    pcl::PointCloud<pcl::SHOT1344>::Ptr descriptors(new pcl::PointCloud<pcl::SHOT1344>());
+
+ // Note: you would usually perform downsampling now. It has been omitted here
+ // for simplicity, but be aware that computation can take a long time.
+
+ // Estimate the normals.
+    pcl::NormalEstimation<pcl::PointXYZRGB, pcl::Normal> normalEstimation;
+ normalEstimation.setInputCloud(cloud);
+ normalEstimation.setRadiusSearch(normal_r);
+ // normalEstimation.setKSearch(40);
+    pcl::search::KdTree<pcl::PointXYZRGB>::Ptr kdtree(new pcl::search::KdTree<pcl::PointXYZRGB>);
+ normalEstimation.setSearchMethod(kdtree);
+ normalEstimation.compute(*normals);
+
+ // SHOT estimation object.
+    pcl::SHOTColorEstimation<pcl::PointXYZRGB, pcl::Normal, pcl::SHOT1344> shot;
+ shot.setInputCloud(cloud);
+ shot.setInputNormals(normals);
+ // The radius that defines which of the keypoint's neighbors are described.
+ // If too large, there may be clutter, and if too small, not enough points may be found.
+ shot.setRadiusSearch(shot_r);
+ shot.compute(*descriptors);
+
+    auto result = py::array_t<float>(descriptors->points.size() * 1344);
+ auto buf = result.request();
+ float *ptr = (float*)buf.ptr;
+
+ for (int i = 0; i < descriptors->points.size(); ++i)
+ {
+ std::copy(&descriptors->points[i].descriptor[0], &descriptors->points[i].descriptor[1344], &ptr[i * 1344]);
+ }
+ return result;
+}
+
+
+PYBIND11_MODULE(shot, m) {
+ pcl::console::setVerbosityLevel(pcl::console::L_ALWAYS);
+ m.def("compute", &compute, py::arg("pc"), py::arg("normal_r")=0.1, py::arg("shot_r")=0.17);
+ m.def("compute_color", &compute_color, py::arg("pc"), py::arg("pc_color"), py::arg("normal_r")=0.1, py::arg("shot_r")=0.17);
+ m.def("estimate_normal", &estimate_normal, py::arg("pc"), py::arg("normal_r")=0.1);
+}
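+
+// A minimal sketch of calling the bound functions from Python once this module is
+// built (shapes inferred from the flattening logic above; numbers are illustrative):
+//   import numpy as np
+//   import shot
+//   pc = np.random.rand(2048, 3).astype(np.float32)
+//   normals = shot.estimate_normal(pc, normal_r=0.1).reshape(-1, 3)
+//   descriptors = shot.compute(pc, normal_r=0.1, shot_r=0.17).reshape(-1, 352)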
diff --git a/temp_data/.gitkeep b/temp_data/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..517daa2
--- /dev/null
+++ b/test.py
@@ -0,0 +1,387 @@
+from typing import Dict, Union
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import logging
+import os
+import time
+import tqdm
+from itertools import combinations
+import numpy as np
+import torch
+import torch.nn as nn
+from sklearn.cluster import KMeans
+
+from datasets.rconfmask_afford_point_tuple_dataset import ArticulationDataset
+from models.roartnet import create_shot_encoder, create_encoder
+from inference import voting_translation, voting_rotation
+from utilities.metrics_utils import calc_translation_error, calc_translation_error_batch, calc_direction_error, calc_direction_error_batch, log_metrics
+from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting, visualize_confidence_voting
+from utilities.env_utils import setup_seed
+from utilities.constants import seed, light_blue_color, red_color, dark_red_color, dark_green_color, yellow_color
+
+
+def test_fn(test_dataloader:torch.utils.data.DataLoader, has_rgb:bool, shot_encoder:nn.Module, encoder:nn.Module,
+ resolution:float, voting_num:int, rot_bin_num:int, angle_tol:float,
+ translation2pc:bool, multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool,
+ rotation_multi_neighbor:bool, neighbor_threshold:float,
+ bmm_size:int, test_num:int, device:int, vis:bool=False) -> Dict[str, Union[np.ndarray, int]]:
+ if rotation_cluster:
+ kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto')
+ else:
+ kmeans = None
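+    # number of candidate rotation axes to sample on the sphere, derived from the
+    # angular tolerance: 4 * pi divided by angle_tol converted to radians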
+ rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi))
+
+ tested_num = 0
+ with torch.no_grad():
+ translation_distance_errors = []
+ translation_along_errors = []
+ translation_perp_errors = []
+ translation_plane_errors = []
+ translation_line_errors = []
+ translation_outliers = []
+ rotation_errors = []
+ rotation_outliers = []
+ affordance_errors = []
+ affordance_outliers = []
+ for batch_data in tqdm.tqdm(test_dataloader):
+ if tested_num >= test_num:
+ break
+ if has_rgb:
+ pcs, pc_normals, pc_shots, pc_colors, joint_translations, joint_rotations, affordable_positions, _, _, _, _, point_idxs_all = batch_data
+ pcs, pc_normals, pc_shots, pc_colors, point_idxs_all = \
+ pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), pc_colors.cuda(device), point_idxs_all.cuda(device)
+ else:
+ pcs, pc_normals, pc_shots, joint_translations, joint_rotations, affordable_positions, _, _, _, _, point_idxs_all = batch_data
+ pcs, pc_normals, pc_shots, point_idxs_all = \
+ pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), point_idxs_all.cuda(device)
+ # (B, N, 3), (B, N, 3), (B, N, 352)(, (B, N, 3)), (B, J, 3), (B, J, 3), (B, J, 3), (B, N_t, 2 + N_m)
+ B = pcs.shape[0]
+ N = pcs.shape[1]
+ J = joint_translations.shape[1]
+ N_t = point_idxs_all.shape[1]
+ tested_num += B
+
+ # shot encoder for every point
+ shot_feat = shot_encoder(pc_shots) # (B, N, N_s)
+
+ # encoder for sampled point tuples
+ # shot_inputs = torch.cat([shot_feat[point_idxs_all[:, i]] for i in range(0, point_idxs_all.shape[-1])], -1) # (sample_points, feature_dim * (2 + num_more))
+ # normal_inputs = torch.cat([torch.max(torch.sum(normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True),
+ # torch.sum(-normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True))
+ # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, (2+num_more \choose 2))
+ # coord_inputs = torch.cat([pc[point_idxs_all[:, i]] - pc[point_idxs_all[:, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, 3 * (2+num_more \choose 2))
+ # shot_inputs = []
+ # normal_inputs = []
+ # coord_inputs = []
+ # for b in range(pcs.shape[0]):
+ # shot_inputs.append(torch.cat([shot_feat[b][point_idxs_all[b, :, i]] for i in range(0, point_idxs_all.shape[-1])], dim=-1)) # (sample_points, feature_dim * (2 + num_more))
+ # normal_inputs.append(torch.cat([torch.max(torch.sum(normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True),
+ # torch.sum(-normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True))
+ # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, (2+num_more \choose 2))
+ # coord_inputs.append(torch.cat([pcs[b][point_idxs_all[b, :, i]] - pcs[b][point_idxs_all[b, :, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, 3 * (2+num_more \choose 2))
+ # shot_inputs = torch.stack(shot_inputs, dim=0) # (B, sample_points, feature_dim * (2 + num_more))
+ # normal_inputs = torch.stack(normal_inputs, dim=0) # (B, sample_points, (2+num_more \choose 2))
+ # coord_inputs = torch.stack(coord_inputs, dim=0) # (B, sample_points, 3 * (2+num_more \choose 2))
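+            # the batched gathers below assemble, for each sampled point tuple, the
+            # concatenated per-point SHOT features, the pairwise normal-agreement terms
+            # (max over the two normal sign choices), and the pairwise coordinate
+            # differences that are fed to the tuple encoder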
+ shot_inputs = torch.cat([
+ torch.gather(shot_feat, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, shot_feat.shape[-1])))
+ for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, N_s * (2 + N_m))
+ normal_inputs = torch.cat([torch.max(
+ torch.sum(torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True),
+ torch.sum(-torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True))
+ for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, (2+N_m \choose 2))
+ coord_inputs = torch.cat([
+ torch.gather(pcs, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pcs.shape[-1]))) -
+ torch.gather(pcs, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pcs.shape[-1])))
+ for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, 3 * (2+N_m \choose 2))
+ if has_rgb:
+ rgb_inputs = torch.cat([
+ torch.gather(pc_colors, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_colors.shape[-1])))
+ for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, 3 * (2 + N_m))
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1)
+ else:
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1)
+ preds = encoder(inputs) # (B, N_t, (2 + N_r + 2 + 1) * J)
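+            # the (2 + N_r + 2 + 1) * J outputs are grouped by prediction type:
+            # first 2*J translation offsets, then N_r*J rotation-bin logits,
+            # then 2*J affordance offsets, and finally J confidence logits;
+            # the slices below follow this layout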
+
+ # voting
+ batch_pred_translations, batch_pred_rotations, batch_pred_affordances = [], [], []
+ pcs_numpy = pcs.cpu().numpy().astype(np.float32) # (B, N, 3)
+ pc_normals_numpy = pc_normals.cpu().numpy().astype(np.float32) # (B, N, 3)
+ joint_translations_numpy = joint_translations.numpy().astype(np.float32) # (B, J, 3)
+ joint_rotations_numpy = joint_rotations.numpy().astype(np.float32) # (B, J, 3)
+ affordable_positions_numpy = affordable_positions.numpy().astype(np.float32) # (B, J, 3)
+ point_idxs_numpy = point_idxs_all[:, :, :2].cpu().numpy().astype(np.int32) # (B, N_t, 2)
+ preds_numpy = preds.cpu().numpy().astype(np.float32) # (B, N_t, (2 + N_r + 2 + 1) * J)
+ for b in range(B):
+ pc = pcs_numpy[b] # (N, 3)
+ pc_normal = pc_normals_numpy[b] # (N, 3)
+ joint_translation = joint_translations_numpy[b] # (J, 3)
+ joint_rotation = joint_rotations_numpy[b] # (J, 3)
+ affordable_position = affordable_positions_numpy[b] # (J, 3)
+ point_idx = point_idxs_numpy[b] # (N_t, 2)
+ pred = preds_numpy[b] # (N_t, (2 + N_r + 2 + 1) * J)
+ pred_tensor = torch.from_numpy(pred)
+
+ pred_translations, pred_rotations, pred_affordances = [], [], []
+ for j in range(J):
+ # conf selection
+ pred_conf = torch.sigmoid(pred_tensor[:, -1*J+j]) # (N_t,)
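+                    # tuples whose confidence falls below 0.5 are zeroed out so they cast no votes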
+ not_selected_indices = pred_conf < 0.5
+ pred_conf[not_selected_indices] = 0
+ # pred_conf[pred_conf > 0] = 1
+ # pred_conf[:] = 1
+ pred_conf = pred_conf.numpy()
+ if vis:
+ visualize_confidence_voting(pred_conf, pc, point_idx,
+ whether_frame=True, whether_bbox=True, window_name='conf_voting')
+ import pdb; pdb.set_trace()
+
+ # translation voting
+ pred_tr = pred[:, 2*j:2*(j+1)] # (N_t, 2)
+ pred_translation, grid_obj, corners = voting_translation(pc, pred_tr, point_idx, pred_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_translations.append(pred_translation)
+
+ # rotation voting
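+                    # decode each tuple's rotation: softmax over the N_r angle bins, sample one
+                    # bin, and map the bin index back to an angle in [0, pi]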
+ pred_rot = pred_tensor[:, (2*J+rot_bin_num*j):(2*J+rot_bin_num*(j+1))] # (N_t, rot_bin_num)
+ pred_rot = torch.softmax(pred_rot, dim=-1)
+ pred_rot = torch.multinomial(pred_rot, 1).float()[:, 0] # (N_t,)
+ pred_rot = pred_rot / (rot_bin_num - 1) * np.pi
+ pred_rot = pred_rot.numpy()
+ pred_direction, sphere_pts, counts = voting_rotation(pc, pred_rot, point_idx, pred_conf,
+ rot_candidate_num, angle_tol, voting_num, bmm_size, device,
+ multi_candidate, candidate_threshold, rotation_cluster, kmeans,
+ rotation_multi_neighbor, neighbor_threshold)
+ pred_rotations.append(pred_direction)
+
+ # affordance voting
+ pred_afford = pred[:, (2*J+rot_bin_num*J+2*j):(2*J+rot_bin_num*J+2*(j+1))] # (N_t, 2)
+ pred_affordance, agrid_obj, acorners = voting_translation(pc, pred_afford, point_idx, pred_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_affordances.append(pred_affordance)
+
+ translation_errors = calc_translation_error(pred_translation, joint_translation[j], pred_direction, joint_rotation[j])
+ if sum(translation_errors) > 20:
+ translation_outliers.append(translation_errors)
+ if vis and sum(translation_errors) > 20:
+ print(f"{translation_errors = }")
+ indices = np.indices(grid_obj.shape)
+ indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape))
+ votes_list = grid_obj.reshape(-1)
+ grid_pc = corners[0] + indices_list * resolution
+ visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color,
+ gt_translation=joint_translation[j], gt_color=dark_green_color,
+ pred_translation=pred_translation, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='tr_voting')
+ import pdb; pdb.set_trace()
+ direction_error = calc_direction_error(pred_direction, joint_rotation[j])
+ if direction_error > 5:
+ rotation_outliers.append(direction_error)
+ if vis and direction_error > 5:
+ print(f"{direction_error = }")
+ visualize_rotation_voting(sphere_pts, counts, pc, pc_color=light_blue_color,
+ gt_rotation=joint_rotation[j], gt_color=dark_green_color,
+ pred_rotation=pred_direction, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='rot_voting')
+ import pdb; pdb.set_trace()
+ affordance_error, _, _, _, _ = calc_translation_error(pred_affordance, affordable_position[j], None, None)
+ if affordance_error > 5:
+ affordance_outliers.append(affordance_error)
+ if vis and affordance_error > 5:
+ print(f"{affordance_error = }")
+ indices = np.indices(agrid_obj.shape)
+ indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(agrid_obj.shape))
+ votes_list = agrid_obj.reshape(-1)
+ grid_pc = acorners[0] + indices_list * resolution
+ visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color,
+ gt_translation=affordable_position[j], gt_color=dark_green_color,
+ pred_translation=pred_affordance, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='afford_voting')
+ import pdb; pdb.set_trace()
+ if vis:
+ visualize(pc, pc_color=light_blue_color, pc_normal=pc_normal,
+ joint_translations=np.array(pred_translations), joint_rotations=np.array(pred_rotations), affordable_positions=np.array(pred_affordances),
+ joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color,
+ whether_frame=True, whether_bbox=True, window_name='pred')
+ import pdb; pdb.set_trace()
+ batch_pred_translations.append(pred_translations)
+ batch_pred_rotations.append(pred_rotations)
+ batch_pred_affordances.append(pred_affordances)
+ batch_pred_translations = np.array(batch_pred_translations).astype(np.float32) # (B, J, 3)
+ batch_pred_rotations = np.array(batch_pred_rotations).astype(np.float32) # (B, J, 3)
+ batch_pred_affordances = np.array(batch_pred_affordances).astype(np.float32) # (B, J, 3)
+ batch_gt_translations = joint_translations.numpy().astype(np.float32) # (B, J, 3)
+ batch_gt_rotations = joint_rotations.numpy().astype(np.float32) # (B, J, 3)
+ batch_gt_affordances = affordable_positions.numpy().astype(np.float32) # (B, J, 3)
+ batch_translation_errors = calc_translation_error_batch(batch_pred_translations, batch_gt_translations, batch_pred_rotations, batch_gt_rotations) # (B, J)
+ batch_rotation_errors = calc_direction_error_batch(batch_pred_rotations, batch_gt_rotations) # (B, J)
+ batch_affordance_errors, _, _, _, _ = calc_translation_error_batch(batch_pred_affordances, batch_gt_affordances, None, None) # (B, J)
+ translation_distance_errors.append(batch_translation_errors[0])
+ translation_along_errors.append(batch_translation_errors[1])
+ translation_perp_errors.append(batch_translation_errors[2])
+ translation_plane_errors.append(batch_translation_errors[3])
+ translation_line_errors.append(batch_translation_errors[4])
+ rotation_errors.append(batch_rotation_errors)
+ affordance_errors.append(batch_affordance_errors)
+ translation_distance_errors = np.concatenate(translation_distance_errors, axis=0) # (tested_num, J)
+ translation_along_errors = np.concatenate(translation_along_errors, axis=0) # (tested_num, J)
+ translation_perp_errors = np.concatenate(translation_perp_errors, axis=0) # (tested_num, J)
+ translation_plane_errors = np.concatenate(translation_plane_errors, axis=0) # (tested_num, J)
+ translation_line_errors = np.concatenate(translation_line_errors, axis=0) # (tested_num, J)
+ rotation_errors = np.concatenate(rotation_errors, axis=0) # (tested_num, J)
+ affordance_errors = np.concatenate(affordance_errors, axis=0) # (tested_num, J)
+
+ return {
+ 'translation_distance_errors': translation_distance_errors,
+ 'translation_along_errors': translation_along_errors,
+ 'translation_perp_errors': translation_perp_errors,
+ 'translation_plane_errors': translation_plane_errors,
+ 'translation_line_errors': translation_line_errors,
+ 'translation_outliers_num': len(translation_outliers),
+ 'rotation_errors': rotation_errors,
+ 'rotation_outliers_num': len(rotation_outliers),
+ 'affordance_errors': affordance_errors,
+ 'affordance_outliers_num': len(affordance_outliers)
+ }
+
+
+@hydra.main(config_path='./configs', config_name='test_config', version_base='1.2')
+def test(cfg:DictConfig) -> None:
+ logger = logging.getLogger('test')
+ hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
+ output_dir = hydra_cfg['runtime']['output_dir']
+ setup_seed(seed=cfg.testing.seed)
+ trained_path = cfg.trained.path
+ trained_cfg = OmegaConf.load(f"{trained_path}/.hydra/config.yaml")
+ # merge trained_cfg into cfg, cfg has higher priority
+ cfg = OmegaConf.merge(trained_cfg, cfg)
+ print(OmegaConf.to_yaml(cfg))
+
+ # prepare dataset
+ logger.info("Preparing dataset...")
+ device = cfg.testing.device
+ training_path = cfg.dataset.train_path
+ testing_path = cfg.dataset.test_path
+ training_categories = cfg.dataset.train_categories
+ testing_categories = cfg.dataset.test_categories
+ joint_num = cfg.dataset.joint_num
+ resolution = cfg.dataset.resolution
+ receptive_field = cfg.dataset.receptive_field
+ noise = cfg.dataset.noise
+ distortion_rate = cfg.dataset.distortion_rate
+ distortion_level = cfg.dataset.distortion_level
+ outlier_rate = cfg.dataset.outlier_rate
+ outlier_level = cfg.dataset.outlier_level
+ rgb = cfg.dataset.rgb
+ denoise = cfg.dataset.denoise
+ normalize = cfg.dataset.normalize
+ sample_points_num = cfg.dataset.sample_points_num
+ sample_tuples_num = cfg.algorithm.sampling.sample_tuples_num
+ tuple_more_num = cfg.algorithm.sampling.tuple_more_num
+ training_dataset = ArticulationDataset(training_path, training_categories, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ noise, distortion_rate, distortion_level, outlier_rate, outlier_level,
+ rgb, denoise, normalize, debug=False, vis=False, is_train=False)
+
+ batch_size = cfg.testing.batch_size
+ num_workers = cfg.testing.num_workers
+ training_dataloader = torch.utils.data.DataLoader(training_dataset, pin_memory=True, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+ testing_dataset = ArticulationDataset(testing_path, testing_categories, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ noise, distortion_rate, distortion_level, outlier_rate, outlier_level,
+ rgb, denoise, normalize, debug=False, vis=False, is_train=False)
+ testing_dataloader = torch.utils.data.DataLoader(testing_dataset, pin_memory=True, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+ logger.info("Prepared dataset.")
+
+ # prepare model
+ logger.info("Preparing model...")
+ shot_hidden_dims = cfg.algorithm.shot_encoder.hidden_dims
+ shot_feature_dim = cfg.algorithm.shot_encoder.feature_dim
+ shot_bn = cfg.algorithm.shot_encoder.bn
+ shot_ln = cfg.algorithm.shot_encoder.ln
+ shot_dropout = cfg.algorithm.shot_encoder.dropout
+ shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim,
+ shot_bn, shot_ln, shot_dropout)
+ shot_encoder.load_state_dict(torch.load(f'{os.path.join(trained_path, "weights")}/shot_encoder_latest.pth', map_location=torch.device(device)))
+ shot_encoder = shot_encoder.cuda(device)
+ overall_hidden_dims = cfg.algorithm.encoder.hidden_dims
+ rot_bin_num = cfg.algorithm.voting.rot_bin_num
+ overall_bn = cfg.algorithm.encoder.bn
+ overall_ln = cfg.algorithm.encoder.ln
+ overall_dropout = cfg.algorithm.encoder.dropout
+ encoder = create_encoder(tuple_more_num, shot_feature_dim, rgb, overall_hidden_dims, rot_bin_num, joint_num,
+ overall_bn, overall_ln, overall_dropout)
+ encoder.load_state_dict(torch.load(f'{os.path.join(trained_path, "weights")}/encoder_latest.pth', map_location=torch.device(device)))
+ encoder = encoder.cuda(device)
+ logger.info("Prepared model.")
+
+ # testing
+ voting_num = cfg.algorithm.voting.voting_num
+ angle_tol = cfg.algorithm.voting.angle_tol
+ translation2pc = cfg.algorithm.voting.translation2pc
+ multi_candidate = cfg.algorithm.voting.multi_candidate
+ candidate_threshold = cfg.algorithm.voting.candidate_threshold
+ rotation_multi_neighbor = cfg.algorithm.voting.rotation_multi_neighbor
+ neighbor_threshold = cfg.algorithm.voting.neighbor_threshold
+ rotation_cluster = cfg.algorithm.voting.rotation_cluster
+ bmm_size = cfg.algorithm.voting.bmm_size
+ logger.info("Testing...")
+ testing_testing_start_time = time.time()
+ shot_encoder.eval()
+ encoder.eval()
+
+ testing_testing_results = test_fn(testing_dataloader, rgb, shot_encoder, encoder,
+ resolution, voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold,
+ bmm_size, len(testing_dataset), device, vis=cfg.vis)
+ log_metrics(testing_testing_results, logger, output_dir, tb_writer=None)
+
+ testing_testing_end_time = time.time()
+ logger.info("Tested.")
+ logger.info("Testing time: " + str(testing_testing_end_time - testing_testing_start_time))
+
+ if cfg.testing.training:
+ logger.info("Testing training...")
+ testing_training_start_time = time.time()
+ shot_encoder.eval()
+ encoder.eval()
+
+ testing_training_results = test_fn(training_dataloader, rgb, shot_encoder, encoder,
+ resolution, voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold,
+ bmm_size, len(training_dataset), device, vis=cfg.vis)
+ log_metrics(testing_training_results, logger, output_dir, tb_writer=None)
+
+ testing_training_end_time = time.time()
+ logger.info("Tested training.")
+ logger.info("Testing training time: " + str(testing_training_end_time - testing_training_start_time))
+ else:
+ pass
+
+
+if __name__ == '__main__':
+ test()
diff --git a/test_gt.py b/test_gt.py
new file mode 100644
index 0000000..aef0c53
--- /dev/null
+++ b/test_gt.py
@@ -0,0 +1,209 @@
+import hydra
+from omegaconf import DictConfig
+import logging
+import tqdm
+import time
+import numpy as np
+import torch
+from sklearn.cluster import KMeans
+
+from datasets.rconfmask_afford_point_tuple_dataset import ArticulationDataset
+from inference import voting_translation, voting_rotation
+from utilities.metrics_utils import calc_translation_error, calc_translation_error_batch, calc_direction_error, calc_direction_error_batch, log_metrics
+from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting
+from utilities.env_utils import setup_seed
+from utilities.constants import seed, light_blue_color, red_color, dark_red_color, dark_green_color, yellow_color
+
+
+@hydra.main(config_path='./configs', config_name='test_gt_config', version_base='1.2')
+def test_gt(cfg:DictConfig) -> None:
+ logger = logging.getLogger('test_gt')
+ hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
+ output_dir = hydra_cfg['runtime']['output_dir']
+ setup_seed(seed=cfg.general.seed)
+
+ # prepare dataset
+ logger.info("Preparing dataset...")
+ device = cfg.general.device
+ path = cfg.dataset.path
+ categories = cfg.dataset.categories
+ joint_num = cfg.dataset.joint_num
+ resolution = cfg.dataset.resolution
+ receptive_field = cfg.dataset.receptive_field
+ denoise = cfg.dataset.denoise
+ normalize = cfg.dataset.normalize
+ sample_points_num = cfg.dataset.sample_points_num
+ sample_tuples_num = cfg.algorithm.sample_tuples_num
+ tuple_more_num = cfg.algorithm.tuple_more_num
+ dataset = ArticulationDataset(path, categories, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ rgb=False, denoise=denoise, normalize=normalize,
+ debug=False, vis=False, is_train=False)
+ batch_size = cfg.general.batch_size
+ num_workers = cfg.general.num_workers
+ dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+ logger.info("Prepared dataset.")
+
+ # test
+ logger.info("GT Testing...")
+    translation2pc = cfg.algorithm.translation2pc # mitigate translation votes that land far from the point cloud, possibly at some cost in precision
+    rotation_cluster = cfg.algorithm.rotation_cluster # mitigate opposite-direction rotation votes, though it cannot fully resolve them when only 1 or 2 candidates remain
+ if rotation_cluster:
+ kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto')
+ else:
+ kmeans = None
+ debug = False
+    multi_candidate = cfg.algorithm.multi_candidate # mitigate the discretization of translation and rotation voting by considering multiple candidates
+    candidate_threshold = cfg.algorithm.candidate_threshold
+    rotation_multi_neighbor = cfg.algorithm.rotation_multi_neighbor # mitigate the discrete angular resolution of rotation voting by also counting neighboring directions
+    neighbor_threshold = cfg.algorithm.neighbor_threshold
+    angle_tol = cfg.algorithm.angle_tol # angular tolerance of rotation voting; larger values ease the discrete-resolution problem but make opposite-direction votes more likely
+ rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi))
+ voting_num = cfg.algorithm.voting_num
+ bmm_size = cfg.algorithm.bmm_size
+ with torch.no_grad():
+ translation_distance_errors = []
+ translation_along_errors = []
+ translation_perp_errors = []
+ translation_plane_errors = []
+ translation_line_errors = []
+ translation_outliers = []
+ rotation_errors = []
+ rotation_outliers = []
+ affordance_errors = []
+ affordance_outliers = []
+ test_gt_start_time = time.time()
+ for pcs, pc_normals, pc_shots, joint_translations, joint_rotations, affordable_positions, targets_tr, targets_rot, targets_afford, targets_conf, point_idxs_all in tqdm.tqdm(dataloader):
+ # (B, N, 3), (B, N, 3), (B, N, 352), (B, J, 3), (B, J, 3), (B, J, 3), (B, J, N_t, 2), (B, J, N_t), (B, J, N_t, 2), (B, J, N_t), (B, N_t, 2 + N_m)
+ B = pcs.shape[0]
+ N_t = targets_tr.shape[2]
+ batch_pred_translations, batch_pred_rotations, batch_pred_affordances = [], [], []
+ for b in range(B):
+ pc = pcs[b].numpy().astype(np.float32) # (N, 3)
+ # pc_normal = pc_normals[b].numpy().astype(np.float32) # (N, 3)
+ # pc_shot = pc_shots[b].numpy().astype(np.float32) # (N, 352)
+ joint_translation = joint_translations[b].numpy().astype(np.float32) # (J, 3)
+ joint_rotation = joint_rotations[b].numpy().astype(np.float32) # (J, 3)
+ affordable_position = affordable_positions[b].numpy().astype(np.float32) # (J, 3)
+ target_tr = targets_tr[b].numpy().astype(np.float32) # (J, N_t, 2)
+ target_rot = targets_rot[b].numpy().astype(np.float32) # (J, N_t)
+ target_afford = targets_afford[b].numpy().astype(np.float32) # (J, N_t, 2)
+ target_conf = targets_conf[b].numpy().astype(np.float32) # (J, N_t)
+ point_idx_all = point_idxs_all[b].numpy().astype(np.int32) # (N_t, 2 + N_m)
+
+ # inference
+ pred_translations, pred_rotations, pred_affordances = [], [], []
+ for j in range(joint_num):
+ this_target_tr = target_tr[j] # (N_t, 2)
+ this_target_rot = target_rot[j] # (N_t,)
+ this_target_afford = target_afford[j] # (N_t, 2)
+ this_target_conf = target_conf[j] # (N_t,)
+
+ pred_translation, grid_obj, corners = voting_translation(pc, this_target_tr, point_idx_all[:, :2], this_target_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_translations.append(pred_translation)
+
+ pred_direction, sphere_pts, counts = voting_rotation(pc, this_target_rot, point_idx_all[:, :2], this_target_conf,
+ rot_candidate_num, angle_tol, voting_num, bmm_size, device,
+ multi_candidate, candidate_threshold, rotation_cluster, kmeans,
+ rotation_multi_neighbor, neighbor_threshold)
+ pred_rotations.append(pred_direction)
+
+ pred_affordance, agrid_obj, acorners = voting_translation(pc, this_target_afford, point_idx_all[:, :2], this_target_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_affordances.append(pred_affordance)
+
+ translation_errors = calc_translation_error(pred_translation, joint_translation[j], pred_direction, joint_rotation[j])
+ if sum(translation_errors) > 20:
+ translation_outliers.append(translation_errors)
+ if debug and sum(translation_errors) > 20:
+ print(f"{translation_errors = }")
+ indices = np.indices(grid_obj.shape)
+ indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape))
+ votes_list = grid_obj.reshape(-1)
+ grid_pc = corners[0] + indices_list * resolution
+ visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color,
+ gt_translation=joint_translation[j], gt_color=dark_green_color,
+ pred_translation=pred_translation, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='tr_voting')
+ import pdb; pdb.set_trace()
+ direction_error = calc_direction_error(pred_direction, joint_rotation[j])
+ if direction_error > 5:
+ rotation_outliers.append(direction_error)
+ if debug and direction_error > 5:
+ print(f"{direction_error = }")
+ visualize_rotation_voting(sphere_pts, counts, pc, pc_color=light_blue_color,
+ gt_rotation=joint_rotation[j], gt_color=dark_green_color,
+ pred_rotation=pred_direction, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='rot_voting')
+ import pdb; pdb.set_trace()
+ affordance_error, _, _, _, _ = calc_translation_error(pred_affordance, affordable_position[j], None, None)
+ if affordance_error > 5:
+ affordance_outliers.append(affordance_error)
+ if debug and affordance_error > 5:
+ print(f"{affordance_error = }")
+ indices = np.indices(agrid_obj.shape)
+ indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(agrid_obj.shape))
+ votes_list = agrid_obj.reshape(-1)
+ grid_pc = acorners[0] + indices_list * resolution
+ visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color,
+ gt_translation=affordable_position[j], gt_color=dark_green_color,
+ pred_translation=pred_affordance, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='afford_voting')
+ import pdb; pdb.set_trace()
+ print(np.sum(this_target_conf), np.sum(this_target_conf) / N_t, sum(translation_errors), direction_error, affordance_error)
+ # if debug:
+ # visualize(pc, pc_color=light_blue_color, pc_normal=pc_normal,
+ # joint_translations=np.array(pred_translations), joint_rotations=np.array(pred_rotations), affordable_positions=np.array(pred_affordances),
+ # joint_axis_colors=red_color, joint_point_colors=dark_red_color,
+ # whether_frame=True, whether_bbox=True, window_name='pred')
+ # import pdb; pdb.set_trace()
+ batch_pred_translations.append(pred_translations)
+ batch_pred_rotations.append(pred_rotations)
+ batch_pred_affordances.append(pred_affordances)
+ batch_pred_translations = np.array(batch_pred_translations).astype(np.float32) # (B, J, 3)
+ batch_pred_rotations = np.array(batch_pred_rotations).astype(np.float32) # (B, J, 3)
+ batch_pred_affordances = np.array(batch_pred_affordances).astype(np.float32) # (B, J, 3)
+ batch_gt_translations = joint_translations.numpy().astype(np.float32) # (B, J, 3)
+ batch_gt_rotations = joint_rotations.numpy().astype(np.float32) # (B, J, 3)
+ batch_gt_affordances = affordable_positions.numpy().astype(np.float32) # (B, J, 3)
+ batch_translation_errors = calc_translation_error_batch(batch_pred_translations, batch_gt_translations, batch_pred_rotations, batch_gt_rotations) # (B, J)
+ batch_rotation_errors = calc_direction_error_batch(batch_pred_rotations, batch_gt_rotations) # (B, J)
+ batch_affordance_errors, _, _, _, _ = calc_translation_error_batch(batch_pred_affordances, batch_gt_affordances, None, None) # (B, J)
+ translation_distance_errors.append(batch_translation_errors[0])
+ translation_along_errors.append(batch_translation_errors[1])
+ translation_perp_errors.append(batch_translation_errors[2])
+ translation_plane_errors.append(batch_translation_errors[3])
+ translation_line_errors.append(batch_translation_errors[4])
+ rotation_errors.append(batch_rotation_errors)
+ affordance_errors.append(batch_affordance_errors)
+ test_gt_end_time = time.time()
+ translation_distance_errors = np.concatenate(translation_distance_errors, axis=0)
+ translation_along_errors = np.concatenate(translation_along_errors, axis=0)
+ translation_perp_errors = np.concatenate(translation_perp_errors, axis=0)
+ translation_plane_errors = np.concatenate(translation_plane_errors, axis=0)
+ translation_line_errors = np.concatenate(translation_line_errors, axis=0)
+ rotation_errors = np.concatenate(rotation_errors, axis=0)
+ affordance_errors = np.concatenate(affordance_errors, axis=0)
+
+ results_dict = {
+ 'translation_distance_errors': translation_distance_errors,
+ 'translation_along_errors': translation_along_errors,
+ 'translation_perp_errors': translation_perp_errors,
+ 'translation_plane_errors': translation_plane_errors,
+ 'translation_line_errors': translation_line_errors,
+ 'translation_outliers_num': len(translation_outliers),
+ 'rotation_errors': rotation_errors,
+ 'rotation_outliers_num': len(rotation_outliers),
+ 'affordance_errors': affordance_errors,
+ 'affordance_outliers_num': len(affordance_outliers)
+ }
+ log_metrics(results_dict, logger, output_dir, tb_writer=None)
+ logger.info(f"Time: {test_gt_end_time - test_gt_start_time}")
+ logger.info("GT Tested.")
+
+
+if __name__ == '__main__':
+ test_gt()
diff --git a/test_real.py b/test_real.py
new file mode 100644
index 0000000..6b5e3a6
--- /dev/null
+++ b/test_real.py
@@ -0,0 +1,356 @@
+from typing import Dict, Union
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import logging
+import os
+import time
+import tqdm
+from itertools import combinations
+import numpy as np
+import torch
+import torch.nn as nn
+from sklearn.cluster import KMeans
+
+from datasets.point_tuple_dataset import ArticulationDataset
+from models.roartnet import create_shot_encoder, create_encoder
+from inference import voting_translation, voting_rotation
+from utilities.metrics_utils import calc_translation_error, calc_translation_error_batch, calc_direction_error, calc_direction_error_batch, log_metrics
+from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting, visualize_confidence_voting
+from utilities.env_utils import setup_seed
+from utilities.constants import seed, light_blue_color, red_color, dark_red_color, dark_green_color, yellow_color
+
+
+def test_fn(test_dataloader:torch.utils.data.DataLoader, has_rgb:bool, shot_encoder:nn.Module, encoder:nn.Module,
+ resolution:float, voting_num:int, rot_bin_num:int, angle_tol:float,
+ translation2pc:bool, multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool,
+ rotation_multi_neighbor:bool, neighbor_threshold:float,
+ bmm_size:int, test_num:int, device:int, vis:bool=False) -> Dict[str, Union[np.ndarray, int]]:
+ if rotation_cluster:
+ kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto')
+ else:
+ kmeans = None
+ rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi))
+
+ tested_num = 0
+ with torch.no_grad():
+ names = []
+ translation_distance_errors = []
+ translation_along_errors = []
+ translation_perp_errors = []
+ translation_plane_errors = []
+ translation_line_errors = []
+ translation_outliers = []
+ rotation_errors = []
+ rotation_outliers = []
+ affordance_errors = []
+ affordance_outliers = []
+ for batch_data in tqdm.tqdm(test_dataloader):
+ if tested_num >= test_num:
+ break
+ if has_rgb:
+ pcs, pc_normals, pc_shots, pc_colors, joint_translations, joint_rotations, affordable_positions, joint_types, point_idxs_all, batch_names = batch_data
+ pcs, pc_normals, pc_shots, pc_colors, point_idxs_all = \
+ pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), pc_colors.cuda(device), point_idxs_all.cuda(device)
+ else:
+ pcs, pc_normals, pc_shots, joint_translations, joint_rotations, affordable_positions, joint_types, point_idxs_all, batch_names = batch_data
+ pcs, pc_normals, pc_shots, point_idxs_all = \
+ pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), point_idxs_all.cuda(device)
+ # (B, N, 3), (B, N, 3), (B, N, 352)(, (B, N, 3)), (B, J, 3), (B, J, 3), (B, J, 3), (B, J), (B, N_t, 2 + N_m), (B,)
+ B = pcs.shape[0]
+ N = pcs.shape[1]
+ J = joint_translations.shape[1]
+ N_t = point_idxs_all.shape[1]
+ tested_num += B
+
+ # shot encoder for every point
+ shot_feat = shot_encoder(pc_shots) # (B, N, N_s)
+
+ # encoder for sampled point tuples
+ # shot_inputs = torch.cat([shot_feat[point_idxs_all[:, i]] for i in range(0, point_idxs_all.shape[-1])], -1) # (sample_points, feature_dim * (2 + num_more))
+ # normal_inputs = torch.cat([torch.max(torch.sum(normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True),
+ # torch.sum(-normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True))
+ # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, (2+num_more \choose 2))
+ # coord_inputs = torch.cat([pc[point_idxs_all[:, i]] - pc[point_idxs_all[:, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, 3 * (2+num_more \choose 2))
+ # shot_inputs = []
+ # normal_inputs = []
+ # coord_inputs = []
+ # for b in range(pcs.shape[0]):
+ # shot_inputs.append(torch.cat([shot_feat[b][point_idxs_all[b, :, i]] for i in range(0, point_idxs_all.shape[-1])], dim=-1)) # (sample_points, feature_dim * (2 + num_more))
+ # normal_inputs.append(torch.cat([torch.max(torch.sum(normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True),
+ # torch.sum(-normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True))
+ # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, (2+num_more \choose 2))
+ # coord_inputs.append(torch.cat([pcs[b][point_idxs_all[b, :, i]] - pcs[b][point_idxs_all[b, :, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, 3 * (2+num_more \choose 2))
+ # shot_inputs = torch.stack(shot_inputs, dim=0) # (B, sample_points, feature_dim * (2 + num_more))
+ # normal_inputs = torch.stack(normal_inputs, dim=0) # (B, sample_points, (2+num_more \choose 2))
+ # coord_inputs = torch.stack(coord_inputs, dim=0) # (B, sample_points, 3 * (2+num_more \choose 2))
+ shot_inputs = torch.cat([
+ torch.gather(shot_feat, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, shot_feat.shape[-1])))
+ for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, N_s * (2 + N_m))
+ normal_inputs = torch.cat([torch.max(
+ torch.sum(torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True),
+ torch.sum(-torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True))
+ for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, (2+N_m \choose 2))
+ coord_inputs = torch.cat([
+ torch.gather(pcs, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pcs.shape[-1]))) -
+ torch.gather(pcs, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pcs.shape[-1])))
+ for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, 3 * (2+N_m \choose 2))
+ if has_rgb:
+ rgb_inputs = torch.cat([
+ torch.gather(pc_colors, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_colors.shape[-1])))
+ for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, 3 * (2 + N_m))
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1)
+ else:
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1)
+ preds = encoder(inputs) # (B, N_t, (2 + N_r + 2 + 1) * J)
+
+ # voting
+ batch_pred_translations, batch_pred_rotations, batch_pred_affordances = [], [], []
+ pcs_numpy = pcs.cpu().numpy().astype(np.float32) # (B, N, 3)
+ pc_normals_numpy = pc_normals.cpu().numpy().astype(np.float32) # (B, N, 3)
+ joint_translations_numpy = joint_translations.numpy().astype(np.float32) # (B, J, 3)
+ joint_rotations_numpy = joint_rotations.numpy().astype(np.float32) # (B, J, 3)
+ affordable_positions_numpy = affordable_positions.numpy().astype(np.float32) # (B, J, 3)
+ point_idxs_numpy = point_idxs_all[:, :, :2].cpu().numpy().astype(np.int32) # (B, N_t, 2)
+ preds_numpy = preds.cpu().numpy().astype(np.float32) # (B, N_t, (2 + N_r + 2 + 1) * J)
+ for b in range(B):
+ pc = pcs_numpy[b] # (N, 3)
+ pc_normal = pc_normals_numpy[b] # (N, 3)
+ joint_translation = joint_translations_numpy[b] # (J, 3)
+ joint_rotation = joint_rotations_numpy[b] # (J, 3)
+ affordable_position = affordable_positions_numpy[b] # (J, 3)
+ point_idx = point_idxs_numpy[b] # (N_t, 2)
+ pred = preds_numpy[b] # (N_t, (2 + N_r + 2 + 1) * J)
+ pred_tensor = torch.from_numpy(pred)
+
+ pred_translations, pred_rotations, pred_affordances = [], [], []
+ for j in range(J):
+ # conf selection
+ pred_conf = torch.sigmoid(pred_tensor[:, -1*J+j]) # (N_t,)
+ not_selected_indices = pred_conf < 0.5
+ pred_conf[not_selected_indices] = 0
+ # pred_conf[pred_conf > 0] = 1
+ pred_conf = pred_conf.numpy()
+ if vis:
+ visualize_confidence_voting(pred_conf, pc, point_idx,
+ whether_frame=True, whether_bbox=True, window_name='conf_voting')
+ import pdb; pdb.set_trace()
+
+ # translation voting
+ pred_tr = pred[:, 2*j:2*(j+1)] # (N_t, 2)
+ pred_translation, grid_obj, corners = voting_translation(pc, pred_tr, point_idx, pred_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_translations.append(pred_translation)
+
+ # rotation voting
+ pred_rot = pred_tensor[:, (2*J+rot_bin_num*j):(2*J+rot_bin_num*(j+1))] # (N_t, rot_bin_num)
+ pred_rot = torch.softmax(pred_rot, dim=-1)
+ pred_rot = torch.multinomial(pred_rot, 1).float()[:, 0] # (N_t,)
+ pred_rot = pred_rot / (rot_bin_num - 1) * np.pi
+ pred_rot = pred_rot.numpy()
+ pred_direction, sphere_pts, counts = voting_rotation(pc, pred_rot, point_idx, pred_conf,
+ rot_candidate_num, angle_tol, voting_num, bmm_size, device,
+ multi_candidate, candidate_threshold, rotation_cluster, kmeans,
+ rotation_multi_neighbor, neighbor_threshold)
+ pred_rotations.append(pred_direction)
+
+ # affordance voting
+ pred_afford = pred[:, (2*J+rot_bin_num*J+2*j):(2*J+rot_bin_num*J+2*(j+1))] # (N_t, 2)
+ pred_affordance, agrid_obj, acorners = voting_translation(pc, pred_afford, point_idx, pred_conf,
+ resolution, voting_num, device,
+ translation2pc, multi_candidate, candidate_threshold)
+ pred_affordances.append(pred_affordance)
+
+ translation_errors = calc_translation_error(pred_translation, joint_translation[j], pred_direction, joint_rotation[j])
+ if sum(translation_errors) > 20:
+ translation_outliers.append(translation_errors)
+ if vis and sum(translation_errors) > 20:
+ print(f"{translation_errors = }")
+ indices = np.indices(grid_obj.shape)
+ indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape))
+ votes_list = grid_obj.reshape(-1)
+ grid_pc = corners[0] + indices_list * resolution
+ visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color,
+ gt_translation=joint_translation[j], gt_color=dark_green_color,
+ pred_translation=pred_translation, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='tr_voting')
+ import pdb; pdb.set_trace()
+ direction_error = calc_direction_error(pred_direction, joint_rotation[j])
+ if direction_error > 5:
+ rotation_outliers.append(direction_error)
+ if vis and direction_error > 5:
+ print(f"{direction_error = }")
+ visualize_rotation_voting(sphere_pts, counts, pc, pc_color=light_blue_color,
+ gt_rotation=joint_rotation[j], gt_color=dark_green_color,
+ pred_rotation=pred_direction, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='rot_voting')
+ import pdb; pdb.set_trace()
+ affordance_error, _, _, _, _ = calc_translation_error(pred_affordance, affordable_position[j], None, None)
+ if affordance_error > 5:
+ affordance_outliers.append(affordance_error)
+ if vis and affordance_error > 5:
+ print(f"{affordance_error = }")
+ indices = np.indices(agrid_obj.shape)
+ indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(agrid_obj.shape))
+ votes_list = agrid_obj.reshape(-1)
+ grid_pc = acorners[0] + indices_list * resolution
+ visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color,
+ gt_translation=affordable_position[j], gt_color=dark_green_color,
+ pred_translation=pred_affordance, pred_color=yellow_color,
+ show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='afford_voting')
+ import pdb; pdb.set_trace()
+ if vis:
+ visualize(pc, pc_color=light_blue_color, pc_normal=pc_normal,
+ joint_translations=np.array(pred_translations), joint_rotations=np.array(pred_rotations), affordable_positions=np.array(pred_affordances),
+ joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color,
+ whether_frame=True, whether_bbox=True, window_name='pred')
+ import pdb; pdb.set_trace()
+ batch_pred_translations.append(pred_translations)
+ batch_pred_rotations.append(pred_rotations)
+ batch_pred_affordances.append(pred_affordances)
+ batch_pred_translations = np.array(batch_pred_translations).astype(np.float32) # (B, J, 3)
+ batch_pred_rotations = np.array(batch_pred_rotations).astype(np.float32) # (B, J, 3)
+ batch_pred_affordances = np.array(batch_pred_affordances).astype(np.float32) # (B, J, 3)
+ batch_gt_translations = joint_translations.numpy().astype(np.float32) # (B, J, 3)
+ batch_gt_rotations = joint_rotations.numpy().astype(np.float32) # (B, J, 3)
+ batch_gt_affordances = affordable_positions.numpy().astype(np.float32) # (B, J, 3)
+ batch_translation_errors = calc_translation_error_batch(batch_pred_translations, batch_gt_translations, batch_pred_rotations, batch_gt_rotations) # (B, J)
+ batch_rotation_errors = calc_direction_error_batch(batch_pred_rotations, batch_gt_rotations) # (B, J)
+ batch_affordance_errors, _, _, _, _ = calc_translation_error_batch(batch_pred_affordances, batch_gt_affordances, None, None) # (B, J)
+ translation_distance_errors.append(batch_translation_errors[0])
+ translation_along_errors.append(batch_translation_errors[1])
+ translation_perp_errors.append(batch_translation_errors[2])
+ translation_plane_errors.append(batch_translation_errors[3])
+ translation_line_errors.append(batch_translation_errors[4])
+ rotation_errors.append(batch_rotation_errors)
+ affordance_errors.append(batch_affordance_errors)
+ names.extend(batch_names)
+ translation_distance_errors = np.concatenate(translation_distance_errors, axis=0) # (tested_num, J)
+ translation_along_errors = np.concatenate(translation_along_errors, axis=0) # (tested_num, J)
+ translation_perp_errors = np.concatenate(translation_perp_errors, axis=0) # (tested_num, J)
+ translation_plane_errors = np.concatenate(translation_plane_errors, axis=0) # (tested_num, J)
+ translation_line_errors = np.concatenate(translation_line_errors, axis=0) # (tested_num, J)
+ rotation_errors = np.concatenate(rotation_errors, axis=0) # (tested_num, J)
+ affordance_errors = np.concatenate(affordance_errors, axis=0) # (tested_num, J)
+
+ return {
+ 'names': names,
+ 'translation_distance_errors': translation_distance_errors,
+ 'translation_along_errors': translation_along_errors,
+ 'translation_perp_errors': translation_perp_errors,
+ 'translation_plane_errors': translation_plane_errors,
+ 'translation_line_errors': translation_line_errors,
+ 'translation_outliers_num': len(translation_outliers),
+ 'rotation_errors': rotation_errors,
+ 'rotation_outliers_num': len(rotation_outliers),
+ 'affordance_errors': affordance_errors,
+ 'affordance_outliers_num': len(affordance_outliers)
+ }
+
+
+@hydra.main(config_path='./configs', config_name='test_real_config', version_base='1.2')
+def test_real(cfg:DictConfig) -> None:
+ logger = logging.getLogger('test_real')
+ hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
+ output_dir = hydra_cfg['runtime']['output_dir']
+ setup_seed(seed=cfg.testing.seed)
+ trained_path = cfg.trained.path
+ trained_cfg = OmegaConf.load(f"{trained_path}/.hydra/config.yaml")
+ # merge trained_cfg into cfg, cfg has higher priority
+ cfg = OmegaConf.merge(trained_cfg, cfg)
+ print(OmegaConf.to_yaml(cfg))
+
+ # prepare dataset
+ logger.info("Preparing dataset...")
+ device = cfg.testing.device
+ path = cfg.dataset.path
+ instances = cfg.dataset.instances
+ joint_num = cfg.dataset.joint_num
+ resolution = cfg.dataset.resolution
+ receptive_field = cfg.dataset.receptive_field
+ rgb = cfg.dataset.rgb
+ denoise = cfg.dataset.denoise
+ normalize = cfg.dataset.normalize
+ sample_points_num = cfg.dataset.sample_points_num
+ sample_tuples_num = cfg.algorithm.sampling.sample_tuples_num
+ tuple_more_num = cfg.algorithm.sampling.tuple_more_num
+ dataset = ArticulationDataset(path, instances, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ rgb, denoise, normalize, debug=False, vis=False, is_train=False)
+
+ batch_size = cfg.testing.batch_size
+ num_workers = cfg.testing.num_workers
+ dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+ logger.info("Prepared dataset.")
+
+ # prepare model
+ logger.info("Preparing model...")
+ shot_hidden_dims = cfg.algorithm.shot_encoder.hidden_dims
+ shot_feature_dim = cfg.algorithm.shot_encoder.feature_dim
+ shot_bn = cfg.algorithm.shot_encoder.bn
+ shot_ln = cfg.algorithm.shot_encoder.ln
+    shot_dropout = cfg.algorithm.shot_encoder.dropout
+    shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim,
+                                       shot_bn, shot_ln, shot_dropout)
+    shot_encoder.load_state_dict(torch.load(os.path.join(trained_path, 'weights', 'shot_encoder_latest.pth'), map_location=torch.device(device)))
+ shot_encoder = shot_encoder.cuda(device)
+ overall_hidden_dims = cfg.algorithm.encoder.hidden_dims
+ rot_bin_num = cfg.algorithm.voting.rot_bin_num
+ overall_bn = cfg.algorithm.encoder.bn
+ overall_ln = cfg.algorithm.encoder.ln
+ overall_dropout = cfg.algorithm.encoder.dropout
+ encoder = create_encoder(tuple_more_num, shot_feature_dim, rgb, overall_hidden_dims, rot_bin_num, joint_num,
+ overall_bn, overall_ln, overall_dropout)
+    encoder.load_state_dict(torch.load(os.path.join(trained_path, 'weights', 'encoder_latest.pth'), map_location=torch.device(device)))
+ encoder = encoder.cuda(device)
+ logger.info("Prepared model.")
+
+ # testing
+ voting_num = cfg.algorithm.voting.voting_num
+ angle_tol = cfg.algorithm.voting.angle_tol
+ translation2pc = cfg.algorithm.voting.translation2pc
+ multi_candidate = cfg.algorithm.voting.multi_candidate
+ candidate_threshold = cfg.algorithm.voting.candidate_threshold
+ rotation_multi_neighbor = cfg.algorithm.voting.rotation_multi_neighbor
+ neighbor_threshold = cfg.algorithm.voting.neighbor_threshold
+ rotation_cluster = cfg.algorithm.voting.rotation_cluster
+ bmm_size = cfg.algorithm.voting.bmm_size
+ logger.info("Testing...")
+ testing_start_time = time.time()
+ shot_encoder.eval()
+ encoder.eval()
+
+ testing_results = test_fn(dataloader, rgb, shot_encoder, encoder,
+ resolution, voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold,
+ bmm_size, len(dataset), device, vis=cfg.vis)
+ log_metrics(testing_results, logger, output_dir, tb_writer=None)
+
+ testing_end_time = time.time()
+ logger.info("Tested.")
+ logger.info("Testing time: " + str(testing_end_time - testing_start_time))
+
+
+if __name__ == '__main__':
+ test_real()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..6da9818
--- /dev/null
+++ b/train.py
@@ -0,0 +1,378 @@
+import hydra
+from omegaconf import DictConfig
+import logging
+import os
+from itertools import combinations
+import time
+import tqdm
+import numpy as np
+import torch
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.tensorboard import SummaryWriter
+from warmup_scheduler import GradualWarmupScheduler
+
+from datasets.rconfmask_afford_point_tuple_dataset import ArticulationDataset
+from models.roartnet import create_shot_encoder, create_encoder
+from test import test_fn
+from utilities.env_utils import setup_seed
+from utilities.metrics_utils import AverageMeter, log_metrics
+from utilities.data_utils import real2prob
+
+
+@hydra.main(config_path='./configs', config_name='train_config', version_base='1.2')
+def train(cfg:DictConfig) -> None:
+ logger = logging.getLogger('train')
+ hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
+ output_dir = hydra_cfg['runtime']['output_dir']
+ setup_seed(seed=cfg.training.seed)
+
+ # prepare dataset
+ logger.info("Preparing dataset...")
+ device = cfg.training.device
+ training_path = cfg.dataset.train_path
+ training_categories = cfg.dataset.train_categories
+ joint_num = cfg.dataset.joint_num
+ resolution = cfg.dataset.resolution
+ receptive_field = cfg.dataset.receptive_field
+ rgb = cfg.dataset.rgb
+ denoise = cfg.dataset.denoise
+ normalize = cfg.dataset.normalize
+ sample_points_num = cfg.dataset.sample_points_num
+ sample_tuples_num = cfg.algorithm.sampling.sample_tuples_num
+ tuple_more_num = cfg.algorithm.sampling.tuple_more_num
+ training_dataset = ArticulationDataset(training_path, training_categories, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ rgb, denoise, normalize, debug=False, vis=False, is_train=True)
+
+ batch_size = cfg.training.batch_size
+ num_workers = cfg.training.num_workers
+ training_dataloader = torch.utils.data.DataLoader(training_dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+
+ testing_training_dataset = ArticulationDataset(training_path, training_categories, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ rgb, denoise, normalize, debug=False, vis=False, is_train=False)
+ testing_training_dataloader = torch.utils.data.DataLoader(testing_training_dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+ testing_path = cfg.dataset.test_path
+ testing_categories = cfg.dataset.test_categories
+ testing_testing_dataset = ArticulationDataset(testing_path, testing_categories, joint_num, resolution, receptive_field,
+ sample_points_num, sample_tuples_num, tuple_more_num,
+ rgb, denoise, normalize, debug=False, vis=False, is_train=False)
+ testing_testing_dataloader = torch.utils.data.DataLoader(testing_testing_dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+ logger.info("Prepared dataset.")
+
+ # prepare model
+ logger.info("Preparing model...")
+ shot_hidden_dims = cfg.algorithm.shot_encoder.hidden_dims
+ shot_feature_dim = cfg.algorithm.shot_encoder.feature_dim
+ shot_bn = cfg.algorithm.shot_encoder.bn
+ shot_ln = cfg.algorithm.shot_encoder.ln
+    shot_dropout = cfg.algorithm.shot_encoder.dropout
+    shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim,
+                                       shot_bn, shot_ln, shot_dropout)
+ shot_encoder = shot_encoder.cuda(device)
+ overall_hidden_dims = cfg.algorithm.encoder.hidden_dims
+ rot_bin_num = cfg.algorithm.voting.rot_bin_num
+ overall_bn = cfg.algorithm.encoder.bn
+ overall_ln = cfg.algorithm.encoder.ln
+ overall_dropout = cfg.algorithm.encoder.dropout
+ encoder = create_encoder(tuple_more_num, shot_feature_dim, rgb, overall_hidden_dims, rot_bin_num, joint_num,
+ overall_bn, overall_ln, overall_dropout)
+ encoder = encoder.cuda(device)
+ logger.info("Prepared model.")
+
+ # optimize
+ logger.info("Optimizing...")
+ training_start_time = time.time()
+ lr = cfg.training.lr
+ weight_decay = cfg.training.weight_decay
+ epoch_num = cfg.training.epoch_num
+ lambda_rot = cfg.training.lambda_rot
+ lambda_afford = cfg.training.lambda_afford
+ lambda_conf = cfg.training.lambda_conf
+ voting_num = cfg.algorithm.voting.voting_num
+ angle_tol = cfg.algorithm.voting.angle_tol
+ translation2pc = cfg.algorithm.voting.translation2pc
+ multi_candidate = cfg.algorithm.voting.multi_candidate
+ candidate_threshold = cfg.algorithm.voting.candidate_threshold
+ rotation_multi_neighbor = cfg.algorithm.voting.rotation_multi_neighbor
+ neighbor_threshold = cfg.algorithm.voting.neighbor_threshold
+ rotation_cluster = cfg.algorithm.voting.rotation_cluster
+ bmm_size = cfg.algorithm.voting.bmm_size
+ opt = optim.Adam([*encoder.parameters(), *shot_encoder.parameters()], lr=lr, weight_decay=weight_decay)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epoch_num, eta_min=lr/100.0)
+ scheduler_warmup = GradualWarmupScheduler(opt, multiplier=1, total_epoch=epoch_num//20, after_scheduler=scheduler)
+ tb_writer = SummaryWriter(log_dir=os.path.join(output_dir, 'tb'))
+ iteration = 0
+ for epoch in range(epoch_num):
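+        # note: the dummy zero_grad/step below ensures the first optimizer step happens
+        # before the first scheduler step, avoiding PyTorch's lr_scheduler ordering warning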
+ if epoch == 0:
+ opt.zero_grad()
+ opt.step()
+ scheduler_warmup.step()
+
+ loss_meter = AverageMeter()
+ loss_tr_meter = AverageMeter()
+ loss_rot_meter = AverageMeter()
+ loss_afford_meter = AverageMeter()
+ loss_conf_meter = AverageMeter()
+
+ # train
+ shot_encoder.train()
+ encoder.train()
+ logger.info("epoch: " + str(epoch) + " lr: " + str(scheduler_warmup.get_last_lr()[0]))
+ tb_writer.add_scalar('lr', scheduler_warmup.get_last_lr()[0], epoch)
+ with tqdm.tqdm(training_dataloader) as t:
+ data_num = 0
+ data_loader_start_time = time.time()
+ for batch_data in t:
+ if rgb:
+ pcs, pc_normals, pc_shots, pc_colors, target_trs, target_rots, target_affords, target_confs, point_idxs_all = batch_data
+ pcs, pc_normals, pc_shots, pc_colors, target_trs, target_rots, target_affords, target_confs, point_idxs_all = \
+ pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), pc_colors.cuda(device), target_trs.cuda(device), target_rots.cuda(device), target_affords.cuda(device), target_confs.cuda(device), point_idxs_all.cuda(device)
+ else:
+ pcs, pc_normals, pc_shots, target_trs, target_rots, target_affords, target_confs, point_idxs_all = batch_data
+ pcs, pc_normals, pc_shots, target_trs, target_rots, target_affords, target_confs, point_idxs_all = \
+ pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), target_trs.cuda(device), target_rots.cuda(device), target_affords.cuda(device), target_confs.cuda(device), point_idxs_all.cuda(device)
+ # (B, N, 3), (B, N, 3), (B, N, 352)(, (B, N, 3)), (B, J, N_t, 2), (B, J, N_t), (B, J, N_t, 2), (B, J, N_t), (B, N_t, 2 + N_m)
+ B = pcs.shape[0]
+ N = pcs.shape[1]
+ J = target_trs.shape[1]
+ N_t = target_trs.shape[2]
+ data_num += B
+
+ opt.zero_grad()
+ dataloader_end_time = time.time()
+ if cfg.debug:
+ logger.warning("Data loader time: " + str(dataloader_end_time - data_loader_start_time))
+
+ forward_start_time = time.time()
+ # shot encoder for every point
+ shot_feat = shot_encoder(pc_shots) # (B, N, N_s)
+
+ # encoder for sampled point tuples
+ # shot_inputs = torch.cat([shot_feat[point_idxs_all[:, i]] for i in range(0, point_idxs_all.shape[-1])], -1) # (sample_points, feature_dim * (2 + num_more))
+ # normal_inputs = torch.cat([torch.max(torch.sum(normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True),
+ # torch.sum(-normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True))
+ # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, (2+num_more \choose 2))
+ # coord_inputs = torch.cat([pc[point_idxs_all[:, i]] - pc[point_idxs_all[:, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, 3 * (2+num_more \choose 2))
+ # shot_inputs = []
+ # normal_inputs = []
+ # coord_inputs = []
+ # for b in range(pcs.shape[0]):
+ # shot_inputs.append(torch.cat([shot_feat[b][point_idxs_all[b, :, i]] for i in range(0, point_idxs_all.shape[-1])], dim=-1)) # (sample_points, feature_dim * (2 + num_more))
+ # normal_inputs.append(torch.cat([torch.max(torch.sum(normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True),
+ # torch.sum(-normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True))
+ # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, (2+num_more \choose 2))
+ # coord_inputs.append(torch.cat([pcs[b][point_idxs_all[b, :, i]] - pcs[b][point_idxs_all[b, :, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, 3 * (2+num_more \choose 2))
+ # shot_inputs = torch.stack(shot_inputs, dim=0) # (B, sample_points, feature_dim * (2 + num_more))
+ # normal_inputs = torch.stack(normal_inputs, dim=0) # (B, sample_points, (2+num_more \choose 2))
+ # coord_inputs = torch.stack(coord_inputs, dim=0) # (B, sample_points, 3 * (2+num_more \choose 2))
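+                # batched form of the per-sample loops sketched above: torch.gather along the
+                # point dimension pulls the per-point SHOT features / normals / coordinates for
+                # every sampled tuple, so the pairwise normal dot products and coordinate
+                # offsets are built without a Python loop over the batch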
+ shot_inputs = torch.cat([
+ torch.gather(shot_feat, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, shot_feat.shape[-1])))
+ for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, N_s * (2 + N_m))
+ normal_inputs = torch.cat([torch.max(
+ torch.sum(torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True),
+ torch.sum(-torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_normals.shape[-1]))) *
+ torch.gather(pc_normals, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pc_normals.shape[-1]))),
+ dim=-1, keepdim=True))
+ for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, (2+N_m \choose 2))
+ coord_inputs = torch.cat([
+ torch.gather(pcs, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pcs.shape[-1]))) -
+ torch.gather(pcs, 1,
+ point_idxs_all[:, :, j:j+1].expand(
+ (B, N_t, pcs.shape[-1])))
+ for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, 3 * (2+N_m \choose 2))
+ if rgb:
+ rgb_inputs = torch.cat([
+ torch.gather(pc_colors, 1,
+ point_idxs_all[:, :, i:i+1].expand(
+ (B, N_t, pc_colors.shape[-1])))
+ for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, 3 * (2 + N_m))
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1)
+ else:
+ inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1)
+ preds = encoder(inputs) # (B, N_t, (2 + N_r + 2 + 1) * J)
+ forward_end_time = time.time()
+ if cfg.debug:
+ logger.warning("Forward time: " + str(forward_end_time - forward_start_time))
+
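+                # per tuple and per joint, preds packs 2 translation targets, N_r rotation bins,
+                # 2 affordance targets and 1 confidence logit; the slices below follow that layout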
+ backward_start_time = time.time()
+ loss = 0
+ # regression loss for translation for topk
+ pred_trs = preds[:, :, 0:(2 * J)] # (B, N_t, 2*J)
+ pred_trs = pred_trs.reshape((B, N_t, J, 2)) # (B, N_t, J, 2)
+ pred_trs = pred_trs.transpose(1, 2) # (B, J, N_t, 2)
+ loss_tr_ = torch.mean((pred_trs - target_trs) ** 2, dim=-1) # (B, J, N_t)
+ loss_tr_ = loss_tr_ * target_confs
+ loss_tr = loss_tr_[loss_tr_ > 0]
+ loss_tr = torch.mean(loss_tr)
+ loss += loss_tr
+ loss_tr_meter.update(loss_tr.item())
+ tb_writer.add_scalar('loss/loss_tr', loss_tr.item(), iteration)
+
+ # classification loss for rotation for topk
+ pred_rots = preds[:, :, (2 * J):(-3 * J)] # (B, N_t, rot_bin_num*J)
+ pred_rots = pred_rots.reshape((B, N_t, J, rot_bin_num)) # (B, N_t, J, rot_bin_num)
+ pred_rots = pred_rots.transpose(1, 2) # (B, J, N_t, rot_bin_num)
+ pred_rots_ = F.log_softmax(pred_rots, dim=-1) # (B, J, N_t, rot_bin_num)
+ target_rots_ = real2prob(target_rots, np.pi, rot_bin_num, circular=False) # (B, J, N_t, rot_bin_num)
+ loss_rot_ = torch.sum(F.kl_div(pred_rots_, target_rots_, reduction='none'), dim=-1) # (B, J, N_t)
+ loss_rot_ = loss_rot_ * target_confs
+ loss_rot = loss_rot_[loss_rot_ > 0]
+ loss_rot = torch.mean(loss_rot)
+ loss_rot *= lambda_rot
+ loss += loss_rot
+ loss_rot_meter.update(loss_rot.item())
+ tb_writer.add_scalar('loss/loss_rot', loss_rot.item(), iteration)
+
+ # regression loss for affordance for topk
+ pred_affords = preds[:, :, (-3 * J):-J] # (B, N_t, 2*J)
+ pred_affords = pred_affords.reshape((B, N_t, J, 2)) # (B, N_t, J, 2)
+ pred_affords = pred_affords.transpose(1, 2) # (B, J, N_t, 2)
+ loss_afford_ = torch.mean((pred_affords - target_affords) ** 2, dim=-1) # (B, J, N_t)
+ loss_afford_ = loss_afford_ * target_confs
+ loss_afford = loss_afford_[loss_afford_ > 0]
+ loss_afford = torch.mean(loss_afford)
+ loss_afford *= lambda_afford
+ loss += loss_afford
+ loss_afford_meter.update(loss_afford.item())
+ tb_writer.add_scalar('loss/loss_afford', loss_afford.item(), iteration)
+
+ # classification loss for goodness
+ pred_confs = preds[:, :, -J:] # (B, N_t, J)
+ pred_confs = pred_confs.transpose(1, 2) # (B, J, N_t)
+ loss_conf = F.binary_cross_entropy_with_logits(pred_confs, target_confs, reduction='none') # (B, J, N_t)
+ loss_conf = torch.mean(loss_conf)
+ loss_conf *= lambda_conf
+ loss += loss_conf
+ loss_conf_meter.update(loss_conf.item())
+ tb_writer.add_scalar('loss/loss_conf', loss_conf.item(), iteration)
+
+ loss.backward(retain_graph=False)
+ # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.)
+ # torch.nn.utils.clip_grad_norm_(shot_encoder.parameters(), 1.)
+ opt.step()
+ backward_end_time = time.time()
+ if cfg.debug:
+ logger.warning("Backward time: " + str(backward_end_time - backward_start_time))
+
+ loss_meter.update(loss.item())
+ tb_writer.add_scalar('loss/loss', loss.item(), iteration)
+
+ t.set_postfix(epoch=epoch, loss=loss_meter.avg, tr=loss_tr_meter.avg, rot=loss_rot_meter.avg, afford=loss_afford_meter.avg, conf=loss_conf_meter.avg)
+
+ iteration += 1
+ data_loader_start_time = time.time()
+ scheduler_warmup.step()
+ tb_writer.add_scalar('loss/loss_tr_avg', loss_tr_meter.avg, epoch)
+ tb_writer.add_scalar('loss/loss_rot_avg', loss_rot_meter.avg, epoch)
+ tb_writer.add_scalar('loss/loss_afford_avg', loss_afford_meter.avg, epoch)
+ tb_writer.add_scalar('loss/loss_conf_avg', loss_conf_meter.avg, epoch)
+ tb_writer.add_scalar('loss/loss_avg', loss_meter.avg, epoch)
+ logger.info("training loss: " + str(loss_tr_meter.avg) + " + " + str(loss_rot_meter.avg) + " + " + \
+ str(loss_afford_meter.avg) + " + " + str(loss_conf_meter.avg) + " = " + str(loss_meter.avg) + ", data num: " + str(data_num))
+
+ # save model
+ if epoch % (epoch_num // 10) == 0:
+ os.makedirs(os.path.join(output_dir, 'weights'), exist_ok=True)
+ torch.save(encoder.state_dict(), os.path.join(output_dir, 'weights', 'encoder_latest.pth'))
+ torch.save(shot_encoder.state_dict(), os.path.join(output_dir, 'weights', 'shot_encoder_latest.pth'))
+
+ # validation
+ if cfg.training.val_training and epoch % (epoch_num // 10) == 0:
+ logger.info("Validating training...")
+ validating_training_start_time = time.time()
+ shot_encoder.eval()
+ encoder.eval()
+
+ validating_training_num = cfg.training.val_training_num if cfg.training.val_training_num > 0 else len(testing_training_dataset)
+ validating_training_results = test_fn(testing_training_dataloader, rgb, shot_encoder, encoder,
+ resolution, voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold,
+ bmm_size, validating_training_num, device, vis=False)
+ log_metrics(validating_training_results, logger, output_dir, tb_writer, epoch, 'training')
+
+ validating_training_end_time = time.time()
+ logger.info("Validated training.")
+ logger.info("Validating training time: " + str(validating_training_end_time - validating_training_start_time))
+
+ if cfg.training.val_testing and epoch % (epoch_num // 10) == 0:
+ logger.info("Validating testing...")
+ validating_testing_start_time = time.time()
+ shot_encoder.eval()
+ encoder.eval()
+
+ validating_testing_num = cfg.training.val_testing_num if cfg.training.val_testing_num > 0 else len(testing_testing_dataset)
+ validating_testing_results = test_fn(testing_testing_dataloader, rgb, shot_encoder, encoder,
+ resolution, voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold,
+ bmm_size, validating_testing_num, device, vis=False)
+ log_metrics(validating_testing_results, logger, output_dir, tb_writer, epoch, 'testing')
+
+ validating_testing_end_time = time.time()
+ logger.info("Validated testing.")
+ logger.info("Validating testing time: " + str(validating_testing_end_time - validating_testing_start_time))
+
+ training_end_time = time.time()
+ logger.info("Optimized.")
+ logger.info("Training time: " + str(training_end_time - training_start_time))
+
+ # test
+ if cfg.training.test_train:
+ logger.info("Testing training...")
+ testing_training_start_time = time.time()
+ shot_encoder.eval()
+ encoder.eval()
+
+ testing_training_results = test_fn(testing_training_dataloader, rgb, shot_encoder, encoder,
+ resolution, voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold,
+ bmm_size, len(testing_training_dataset), device, vis=False)
+ log_metrics(testing_training_results, logger, output_dir, tb_writer, epoch_num, 'training')
+
+ testing_training_end_time = time.time()
+ logger.info("Tested training.")
+ logger.info("Testing training time: " + str(testing_training_end_time - testing_training_start_time))
+
+ if cfg.training.test_test:
+ logger.info("Testing testing...")
+ testing_testing_start_time = time.time()
+ shot_encoder.eval()
+ encoder.eval()
+
+ testing_testing_results = test_fn(testing_testing_dataloader, rgb, shot_encoder, encoder,
+ resolution, voting_num, rot_bin_num, angle_tol,
+ translation2pc, multi_candidate, candidate_threshold, rotation_cluster,
+ rotation_multi_neighbor, neighbor_threshold,
+ bmm_size, len(testing_testing_dataset), device, vis=False)
+ log_metrics(testing_testing_results, logger, output_dir, tb_writer, epoch_num, 'testing')
+
+ testing_testing_end_time = time.time()
+ logger.info("Tested testing.")
+ logger.info("Testing testing time: " + str(testing_testing_end_time - testing_testing_start_time))
+
+ # save model
+ os.makedirs(os.path.join(output_dir, 'weights'), exist_ok=True)
+ torch.save(encoder.state_dict(), os.path.join(output_dir, 'weights', 'encoder_latest.pth'))
+ torch.save(shot_encoder.state_dict(), os.path.join(output_dir, 'weights', 'shot_encoder_latest.pth'))
+
+
+if __name__ == '__main__':
+ train()
diff --git a/utilities/__init__.py b/utilities/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utilities/constants.py b/utilities/constants.py
new file mode 100644
index 0000000..f96637a
--- /dev/null
+++ b/utilities/constants.py
@@ -0,0 +1,15 @@
+import numpy as np
+
+EPS = 1e-3
+
+seed = 42
+
+light_blue_color = np.array([126/255, 208/255, 248/255])
+dark_green_color = np.array([42/255, 157/255, 142/255])
+gray_color = np.array([204/255, 204/255, 204/255])
+dark_gray_color = np.array([179/255, 179/255, 179/255])
+red_color = np.array([232/255, 17/255, 35/255])
+dark_red_color = np.array([111/255, 15/255, 21/255])
+yellow_color = np.array([255/255, 220/255, 126/255])
+
+max_grasp_width = 0.08
diff --git a/utilities/data_utils.py b/utilities/data_utils.py
new file mode 100644
index 0000000..230d256
--- /dev/null
+++ b/utilities/data_utils.py
@@ -0,0 +1,553 @@
+from typing import List, Tuple, Union, Dict, Optional
+import math
+import numpy as np
+from scipy.spatial.transform import Rotation as srot
+from scipy.optimize import least_squares, linear_sum_assignment
+import torch
+import xml.etree.ElementTree as ET
+
+from .metrics_utils import calc_translation_error_batch, calc_direction_error_batch
+
+
+def read_joints_from_urdf_file(urdf_file):
+ tree_urdf = ET.parse(urdf_file)
+ root_urdf = tree_urdf.getroot()
+
+ joint_dict = {}
+ for joint in root_urdf.iter('joint'):
+ joint_name = joint.attrib['name']
+ joint_type = joint.attrib['type']
+ for child in joint.iter('child'):
+ joint_child = child.attrib['link']
+ for parent in joint.iter('parent'):
+ joint_parent = parent.attrib['link']
+ for origin in joint.iter('origin'):
+ if 'xyz' in origin.attrib:
+ joint_xyz = [float(x) for x in origin.attrib['xyz'].split()]
+ else:
+ joint_xyz = [0, 0, 0]
+ if 'rpy' in origin.attrib:
+ joint_rpy = [float(x) for x in origin.attrib['rpy'].split()]
+ else:
+ joint_rpy = [0, 0, 0]
+ if joint_type == 'prismatic' or joint_type == 'revolute' or joint_type == 'continuous':
+ for axis in joint.iter('axis'):
+ joint_axis = [float(x) for x in axis.attrib['xyz'].split()]
+ else:
+ joint_axis = None
+ if joint_type == 'prismatic' or joint_type == 'revolute':
+ for limit in joint.iter('limit'):
+ joint_limit = [float(limit.attrib['lower']), float(limit.attrib['upper'])]
+ else:
+ joint_limit = None
+
+ joint_dict[joint_name] = {
+ 'type': joint_type,
+ 'parent': joint_parent,
+ 'child': joint_child,
+ 'xyz': joint_xyz,
+ 'rpy': joint_rpy,
+ 'axis': joint_axis,
+ 'limit': joint_limit
+ }
+
+ return joint_dict
+
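+# usage sketch (editor's note; the path below is only illustrative):
+#   joints = read_joints_from_urdf_file('mobility.urdf')
+#   for name, info in joints.items():
+#       if info['type'] in ('revolute', 'prismatic'):
+#           print(name, info['axis'], info['limit'])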
+
+def pc_normalize(pc:np.ndarray, normalize_method:str) -> Tuple[np.ndarray, np.ndarray, float]:
+ if normalize_method == 'none':
+ pc_normalized = pc
+ center = np.array([0., 0., 0.]).astype(pc.dtype)
+ scale = 1.
+ elif normalize_method == 'mean':
+ center = np.mean(pc, axis=0)
+ pc_normalized = pc - center
+ scale = np.max(np.sqrt(np.sum(pc_normalized ** 2, axis=1)))
+ pc_normalized = pc_normalized / scale
+ elif normalize_method == 'bound':
+ center = (np.max(pc, axis=0) + np.min(pc, axis=0)) / 2
+ pc_normalized = pc - center
+ scale = np.max(np.sqrt(np.sum(pc_normalized ** 2, axis=1)))
+ pc_normalized = pc_normalized / scale
+ elif normalize_method == 'median':
+ center = np.median(pc, axis=0)
+ pc_normalized = pc - center
+ scale = np.max(np.sqrt(np.sum(pc_normalized ** 2, axis=1)))
+ pc_normalized = pc_normalized / scale
+ else:
+ raise NotImplementedError
+
+ return (pc_normalized, center, scale)
+
+def joints_normalize(joint_translations:Optional[np.ndarray], joint_rotations:Optional[np.ndarray], center:np.ndarray, scale:float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+ if joint_translations is None and joint_rotations is None:
+ return (None, None)
+
+ J = joint_translations.shape[0] if joint_translations is not None else joint_rotations.shape[0]
+ joint_translations_normalized, joint_rotations_normalized = [], []
+ for j in range(J):
+ if joint_translations is not None:
+ joint_translation_normalized = (joint_translations[j] - center) / scale
+ joint_translations_normalized.append(joint_translation_normalized)
+ if joint_rotations is not None:
+ # joint_axis_normalized = (joint_rotations[j] - center) / scal
+ # joint_rotation_normalized = joint_axis_normalized - joint_translation_normalized
+ # joint_rotation_normalized /= np.linalg.norm(joint_rotation_normalized)
+ joint_rotation_normalized = joint_rotations[j].copy()
+ joint_rotation_normalized /= np.linalg.norm(joint_rotation_normalized)
+ joint_rotations_normalized.append(joint_rotation_normalized)
+ if joint_translations is not None:
+ joint_translations_normalized = np.array(joint_translations_normalized).astype(joint_translations.dtype)
+ else:
+ joint_translations_normalized = None
+ if joint_rotations is not None:
+ joint_rotations_normalized = np.array(joint_rotations_normalized).astype(joint_rotations.dtype)
+ else:
+ joint_rotations_normalized = None
+ return (joint_translations_normalized, joint_rotations_normalized)
+
+def joints_denormalize(joint_translations_normalized:Optional[np.ndarray], joint_rotations_normalized:Optional[np.ndarray], center:np.ndarray, scale:float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+ if joint_translations_normalized is None and joint_rotations_normalized is None:
+ return (None, None)
+
+ J = joint_translations_normalized.shape[0] if joint_translations_normalized is not None else joint_rotations_normalized.shape[0]
+ joint_translations, joint_rotations = [], []
+ for j in range(J):
+ if joint_translations_normalized is not None:
+ joint_translation = joint_translations_normalized[j] * scale + center
+ joint_translations.append(joint_translation)
+ if joint_rotations_normalized is not None:
+ # joint_axis = (joint_translations_normalized[j] + joint_rotations_normalized[j]) * scale + center
+ # joint_rotation = joint_axis - joint_translation
+ joint_rotation = joint_rotations_normalized[j].copy()
+ joint_rotation /= np.linalg.norm(joint_rotation)
+ joint_rotations.append(joint_rotation)
+ if joint_translations_normalized is not None:
+ joint_translations = np.array(joint_translations).astype(joint_translations_normalized.dtype)
+ else:
+ joint_translations = None
+ if joint_rotations_normalized is not None:
+ joint_rotations = np.array(joint_rotations).astype(joint_rotations_normalized.dtype)
+ else:
+ joint_rotations = None
+ return (joint_translations, joint_rotations)
+
+def joint_denormalize(joint_translation_normalized:Optional[np.ndarray], joint_rotation_normalized:Optional[np.ndarray], center:np.ndarray, scale:float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+ if joint_translation_normalized is not None:
+ joint_translation = joint_translation_normalized * scale + center
+ else:
+ joint_translation = None
+ if joint_rotation_normalized is not None:
+ # joint_axis = (joint_translation_normalized + joint_rotation_normalized) * scale + center
+ # joint_rotation = joint_axis - joint_translation
+ joint_rotation = joint_rotation_normalized.copy()
+ joint_rotation /= np.linalg.norm(joint_rotation)
+ else:
+ joint_rotation = None
+ return (joint_translation, joint_rotation)
+
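+# round-trip sketch (editor's note; `pc` is assumed to be an (N, 3) array):
+#   pc_n, center, scale = pc_normalize(pc, 'bound')
+#   trs_n, rots_n = joints_normalize(joint_translations, joint_rotations, center, scale)
+#   trs, rots = joints_denormalize(trs_n, rots_n, center, scale)  # recovers the originals
+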
+def transform_pc(pc_camera:np.ndarray, c2w:np.ndarray) -> np.ndarray:
+ # pc_camera: (N, 3), c2w: (4, 4)
+ pc_camera_hm = np.concatenate([pc_camera, np.ones((pc_camera.shape[0], 1), dtype=pc_camera.dtype)], axis=-1) # (N, 4)
+ pc_world_hm = pc_camera_hm @ c2w.T # (N, 4)
+ pc_world = pc_world_hm[:, :3] # (N, 3)
+ return pc_world
+
+def transform_dir(dir_camera:np.ndarray, c2w:np.ndarray) -> np.ndarray:
+ # dir_camera: (N, 3), c2w: (4, 4)
+ dir_camera_hm = np.concatenate([dir_camera, np.zeros((dir_camera.shape[0], 1), dtype=dir_camera.dtype)], axis=-1) # (N, 4)
+ dir_world_hm = dir_camera_hm @ c2w.T # (N, 4)
+ dir_world = dir_world_hm[:, :3] # (N, 3)
+ return dir_world
+
+
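+# voting targets for a point pair (a, b): the joint origin o is encoded by its signed
+# projection onto the pair direction and its perpendicular distance to the pair line,
+# and the joint axis by its angle to the pair direction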
+def generate_target_tr(pc:np.ndarray, o:np.ndarray, point_idxs:np.ndarray) -> np.ndarray:
+ a = pc[point_idxs[:, 0]] # (N_t, 3)
+ b = pc[point_idxs[:, 1]] # (N_t, 3)
+ pdist = a - b
+ pdist_unit = pdist / (np.linalg.norm(pdist, axis=-1, keepdims=True) + 1e-7)
+ proj_len = np.sum((a - o) * pdist_unit, -1)
+ oc = a - o - proj_len[..., None] * pdist_unit
+ dist2o = np.linalg.norm(oc, axis=-1)
+ target_tr = np.stack([proj_len, dist2o], -1)
+ return target_tr.astype(np.float32).reshape((-1, 2))
+
+def generate_target_rot(pc:np.ndarray, axis:np.ndarray, point_idxs:np.ndarray) -> np.ndarray:
+ a = pc[point_idxs[:, 0]] # (N_t, 3)
+ b = pc[point_idxs[:, 1]] # (N_t, 3)
+ pdist = a - b
+ pdist_unit = pdist / (np.linalg.norm(pdist, axis=-1, keepdims=True) + 1e-7)
+ cos = np.sum(pdist_unit * axis, axis=-1)
+ cos = np.clip(cos, -1., 1.)
+ target_rot = np.arccos(cos)
+ return target_rot.astype(np.float32).reshape((-1,))
+
+
+def farthest_point_sample(point:np.ndarray, npoint:int) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Input:
+ xyz: pointcloud data, [N, D]
+ npoint: number of samples
+ Return:
+ point: sampled pointcloud, [npoint, D]
+ centroids: sampled pointcloud index
+ """
+ N, D = point.shape
+ xyz = point[:,:3]
+ centroids = np.zeros((npoint,))
+ distance = np.ones((N,)) * 1e10
+ farthest = np.random.randint(0, N)
+ for i in range(npoint):
+ centroids[i] = farthest
+ centroid = xyz[farthest, :]
+ dist = np.sum((xyz - centroid) ** 2, -1)
+ mask = dist < distance
+ distance[mask] = dist[mask]
+ farthest = np.argmax(distance, -1)
+ centroids = centroids.astype(np.int32)
+ point = point[centroids]
+ return (point, centroids)
+
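+# note: plain numpy farthest point sampling with O(npoint * N) distance updates; it can be
+# slow for very large point clouds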
+
+def fibonacci_sphere(samples:int) -> List[Tuple[float, float, float]]:
+ points = []
+ phi = math.pi * (3. - math.sqrt(5.)) # golden angle in radians
+
+ for i in range(samples):
+ y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1
+ radius = math.sqrt(1 - y * y) # radius at y
+
+ theta = phi * i # golden angle increment
+
+ x = math.cos(theta) * radius
+ z = math.sin(theta) * radius
+
+ points.append((x, y, z))
+
+ return points
+
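+# the Fibonacci lattice spreads points near-uniformly over the sphere; it is used to build
+# the rotation candidate set (see the __main__ check at the bottom of this file)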
+
+def real2prob(val:Union[torch.Tensor, np.ndarray], max_val:float, num_bins:int, circular:bool=False) -> Union[torch.Tensor, np.ndarray]:
+ is_torch = isinstance(val, torch.Tensor)
+ if is_torch:
+ res = torch.zeros((*val.shape, num_bins), dtype=val.dtype).to(val.device)
+ else:
+ res = np.zeros((*val.shape, num_bins), dtype=val.dtype)
+
+ if not circular:
+ interval = max_val / (num_bins - 1)
+ if is_torch:
+ low = torch.clamp(torch.floor(val / interval).long(), max=num_bins - 2)
+ else:
+ low = np.clip(np.floor(val / interval).astype(np.int64), a_min=None, a_max=num_bins - 2)
+ high = low + 1
+ # assert torch.all(low >= 0) and torch.all(high < num_bins)
+
+ # huge memory
+ if is_torch:
+ res.scatter_(-1, low[..., None], torch.unsqueeze(1. - (val / interval - low), -1))
+ res.scatter_(-1, high[..., None], 1. - torch.gather(res, -1, low[..., None]))
+ else:
+ np.put_along_axis(res, low[..., None], np.expand_dims(1. - (val / interval - low), -1), -1)
+ np.put_along_axis(res, high[..., None], 1. - np.take_along_axis(res, low[..., None], -1), -1)
+ # res[..., low] = 1. - (val / interval - low)
+ # res[..., high] = 1. - res[..., low]
+ # assert torch.all(0 <= res[..., low]) and torch.all(1 >= res[..., low])
+ return res
+ else:
+ interval = max_val / num_bins
+ if is_torch:
+ val_new = torch.clone(val)
+ else:
+ val_new = val.copy()
+ val_new[val < interval / 2] += max_val
+ res = real2prob(val_new - interval / 2, max_val, num_bins + 1)
+ res[..., 0] += res[..., -1]
+ return res[..., :-1]
+
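+# worked example (editor's sketch): real2prob(np.array([np.pi / 4]), np.pi, 3) has
+# interval = pi / 2, so the value falls halfway between bins 0 and 1 -> [[0.5, 0.5, 0.0]]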
+
+def pc_ncs(pc:np.ndarray, bbox_min:np.ndarray, bbox_max:np.ndarray) -> np.ndarray:
+ return (pc - (bbox_min + bbox_max) / 2 + 0.5 * (bbox_max - bbox_min)) / (bbox_max - bbox_min)
+
+def joints_ncs(joint_translations:np.ndarray, joint_rotations:Optional[np.ndarray], bbox_min:np.ndarray, bbox_max:np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+ return ((joint_translations - (bbox_min + bbox_max) / 2 + 0.5 * (bbox_max - bbox_min)) / (bbox_max - bbox_min), joint_rotations.copy() if joint_rotations is not None else None)
+
+
+def rotate_points_with_rotvec(points, rot_vecs):
+ """Rotate points by given rotation vectors.
+
+ Rodrigues' rotation formula is used.
+ """
+ theta = np.linalg.norm(rot_vecs, axis=1)[:, np.newaxis]
+ with np.errstate(invalid='ignore'):
+ v = rot_vecs / theta
+ v = np.nan_to_num(v)
+ dot = np.sum(points * v, axis=1)[:, np.newaxis]
+ cos_theta = np.cos(theta)
+ sin_theta = np.sin(theta)
+
+ return cos_theta * points + sin_theta * np.cross(v, points) + dot * (1 - cos_theta) * v
+
+
+def scale_pts(source, target):
+ # compute scaling factor between source: [N x 3], target: [N x 3]
+ pdist_s = source.reshape(source.shape[0], 1, 3) - source.reshape(1, source.shape[0], 3)
+ A = np.sqrt(np.sum(pdist_s**2, 2)).reshape(-1)
+ pdist_t = target.reshape(target.shape[0], 1, 3) - target.reshape(1, target.shape[0], 3)
+ b = np.sqrt(np.sum(pdist_t**2, 2)).reshape(-1)
+ scale = np.dot(A, b) / (np.dot(A, A)+1e-6)
+ return scale
+
+def rotate_pts(source, target):
+ # compute rotation between source: [N x 3], target: [N x 3]
+ # pre-centering
+ source = source - np.mean(source, 0, keepdims=True)
+ target = target - np.mean(target, 0, keepdims=True)
+ M = np.matmul(target.T, source)
+ U, D, Vh = np.linalg.svd(M, full_matrices=True)
+ d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0
+ if d:
+ D[-1] = -D[-1]
+ U[:, -1] = -U[:, -1]
+ R = np.matmul(U, Vh)
+ return R
+
+
+def objective_eval_t(params, x0, y0, x1, y1, joints, isweight=True):
+ # params: [:3] R0, [3:] R1
+ # x0: N x 3, y0: N x 3, x1: M x 3, y1: M x 3, R0: 1 x 3, R1: 1 x 3, joints: K x 3
+ rotvec0 = params[:3].reshape((1,3))
+ rotvec1 = params[3:].reshape((1,3))
+ res0 = y0 - rotate_points_with_rotvec(x0, rotvec0)
+ res1 = y1 - rotate_points_with_rotvec(x1, rotvec1)
+ res_R = rotvec0 - rotvec1
+ if isweight:
+ res0 /= x0.shape[0]
+ res1 /= x1.shape[0]
+ return np.concatenate((res0, res1, res_R), 0).ravel()
+
+def objective_eval_r(params, x0, y0, x1, y1, joints, isweight=True):
+ # params: [:3] R0, [3:] R1
+ # x0: N x 3, y0: N x 3, x1: M x 3, y1: M x 3, R0: 1 x 3, R1: 1 x 3, joints: K x 3
+ rotvec0 = params[:3].reshape((1,3))
+ rotvec1 = params[3:].reshape((1,3))
+ res0 = y0 - rotate_points_with_rotvec(x0, rotvec0)
+ res1 = y1 - rotate_points_with_rotvec(x1, rotvec1)
+ res_joint = rotate_points_with_rotvec(joints, rotvec0) - rotate_points_with_rotvec(joints, rotvec1)
+ if isweight:
+ res0 /= x0.shape[0]
+ res1 /= x1.shape[0]
+ res_joint /= joints.shape[0]
+ return np.concatenate((res0, res1, res_joint), 0).ravel()
+
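+# both objectives fit one rotation per part; objective_eval_t additionally ties the two
+# rotation vectors together (suited to prismatic joints, where the parts rotate together),
+# while objective_eval_r only requires the two rotations to agree on the joint direction
+# (suited to revolute joints, where the parts may rotate differently about that axis)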
+
+def joint_transformation_estimator(dataset:Dict[str, np.ndarray], joint_type:str, best_inliers:Optional[Tuple[np.ndarray, np.ndarray]]=None) -> Optional[Dict[str, np.ndarray]]:
+ nsource0 = dataset['source0'].shape[0]
+ nsource1 = dataset['source1'].shape[0]
+ if nsource0 < 3 or nsource1 < 3:
+ return None
+ if best_inliers is None:
+ sample_idx0 = np.random.randint(nsource0, size=3)
+ sample_idx1 = np.random.randint(nsource1, size=3)
+ else:
+ sample_idx0 = best_inliers[0]
+ sample_idx1 = best_inliers[1]
+
+ source0 = dataset['source0'][sample_idx0, :]
+ target0 = dataset['target0'][sample_idx0, :]
+ source1 = dataset['source1'][sample_idx1, :]
+ target1 = dataset['target1'][sample_idx1, :]
+
+ scale0 = scale_pts(source0, target0)
+ scale1 = scale_pts(source1, target1)
+ scale0_inv = scale_pts(target0, source0)
+ scale1_inv = scale_pts(target1, source1)
+
+ target0_scaled_centered = scale0_inv*target0
+ target0_scaled_centered -= np.mean(target0_scaled_centered, 0, keepdims=True)
+ source0_centered = source0 - np.mean(source0, 0, keepdims=True)
+
+ target1_scaled_centered = scale1_inv*target1
+ target1_scaled_centered -= np.mean(target1_scaled_centered, 0, keepdims=True)
+ source1_centered = source1 - np.mean(source1, 0, keepdims=True)
+
+    joint_points0 = np.ones((np.min((source0.shape[0], source1.shape[0])), 1)) * dataset['joint_direction'].reshape((1, 3))
+
+ R0 = rotate_pts(source0_centered, target0_scaled_centered)
+ R1 = rotate_pts(source1_centered, target1_scaled_centered)
+
+ rotvec0 = srot.from_matrix(R0).as_rotvec()
+ rotvec1 = srot.from_matrix(R1).as_rotvec()
+ if joint_type == 'prismatic':
+ res = least_squares(objective_eval_t, np.hstack((rotvec0, rotvec1)), verbose=0, ftol=1e-4, method='lm',
+ args=(source0_centered, target0_scaled_centered, source1_centered, target1_scaled_centered, joint_points0, False))
+ elif joint_type == 'revolute':
+ res = least_squares(objective_eval_r, np.hstack((rotvec0, rotvec1)), verbose=0, ftol=1e-4, method='lm',
+ args=(source0_centered, target0_scaled_centered, source1_centered, target1_scaled_centered, joint_points0, False))
+ else:
+ raise ValueError
+ R0 = srot.from_rotvec(res.x[:3]).as_matrix()
+ R1 = srot.from_rotvec(res.x[3:]).as_matrix()
+
+ translation0 = np.mean(target0.T-scale0*np.matmul(R0, source0.T), 1)
+ translation1 = np.mean(target1.T-scale1*np.matmul(R1, source1.T), 1)
+
+    if np.isnan(translation0).any() or np.isnan(translation1).any() or np.isnan(R0).any() or np.isnan(R1).any():
+ return None
+
+ jtrans = dict()
+ jtrans['rotation0'] = R0
+ jtrans['scale0'] = scale0
+ jtrans['translation0'] = translation0
+ jtrans['rotation1'] = R1
+ jtrans['scale1'] = scale1
+ jtrans['translation1'] = translation1
+ return jtrans
+
+def joint_transformation_verifier(dataset:Dict[str, np.ndarray], model:Dict[str, np.ndarray], inlier_th:float) -> Tuple[float, Tuple[np.ndarray, np.ndarray]]:
+ res0 = dataset['target0'].T - model['scale0'] * np.matmul( model['rotation0'], dataset['source0'].T ) - model['translation0'].reshape((3, 1))
+ inliers0 = np.sqrt(np.sum(res0**2, 0)) < inlier_th
+ res1 = dataset['target1'].T - model['scale1'] * np.matmul( model['rotation1'], dataset['source1'].T ) - model['translation1'].reshape((3, 1))
+ inliers1 = np.sqrt(np.sum(res1**2, 0)) < inlier_th
+ score = ( np.sum(inliers0)/res0.shape[0] + np.sum(inliers1)/res1.shape[0] ) / 2
+ return (score, (inliers0, inliers1))
+
+def ransac(dataset:Dict[str, np.ndarray], inlier_threshold:float, iteration_num:int, joint_type:str) -> Optional[Dict[str, np.ndarray]]:
+ best_model = None
+ best_score = -np.inf
+ best_inliers = None
+ for i in range(iteration_num):
+ cur_model = joint_transformation_estimator(dataset, joint_type=joint_type, best_inliers=None)
+ if cur_model is None:
+ return None
+ cur_score, cur_inliers = joint_transformation_verifier(dataset, cur_model, inlier_threshold)
+ if cur_score > best_score:
+ best_model = cur_model
+ best_inliers = cur_inliers
+ best_score = cur_score
+ best_model = joint_transformation_estimator(dataset, joint_type=joint_type, best_inliers=best_inliers)
+ return best_model
+
+
+def match_joints(proposal_joints:List[Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]],
+ gt_joint_translations:np.ndarray, gt_joint_rotations:np.ndarray, match_metric:str,
+ has_affordance:bool=False) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+ # TODO: affordance-related not very well
+ proposal_num = len(proposal_joints) # N
+ gt_num = gt_joint_translations.shape[0] # M
+ if proposal_num == 0:
+ pred_joint_translations = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype)
+ pred_joint_rotations = np.ones_like(gt_joint_rotations, dtype=gt_joint_rotations.dtype)
+ pred_joint_rotations = pred_joint_rotations / np.linalg.norm(pred_joint_rotations, axis=-1, keepdims=True)
+ if has_affordance:
+ pred_affordances = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype)
+ return (pred_joint_translations, pred_joint_rotations, pred_affordances)
+ else:
+ return (pred_joint_translations, pred_joint_rotations)
+
+ proposal_joint_translations = np.array([proposal_joints[i][0] for i in range(proposal_num)]) # (N, 3)
+ proposal_joint_rotations = np.array([proposal_joints[i][1] for i in range(proposal_num)]) # (N, 3)
+ if has_affordance:
+ proposal_affordances = np.array([proposal_joints[i][2] for i in range(proposal_num)]) # (N, 3)
+ else:
+ pass
+ cost_matrix = np.zeros((gt_num, proposal_num)) # (M, N)
+ for gt_idx in range(gt_num):
+ gt_joint_translation = gt_joint_translations[gt_idx].reshape((1, 3)).repeat(proposal_num, axis=0) # (N, 3)
+ gt_joint_rotation = gt_joint_rotations[gt_idx].reshape((1, 3)).repeat(proposal_num, axis=0) # (N, 3)
+        translation_errors = calc_translation_error_batch(proposal_joint_translations, gt_joint_translation, proposal_joint_rotations, gt_joint_rotation) # 5-tuple of (N,)
+ direction_errors = calc_direction_error_batch(proposal_joint_rotations, gt_joint_rotation) # (N,)
+ if match_metric == 'tr_dist':
+ cost_matrix[gt_idx] = translation_errors[0]
+ elif match_metric == 'tr_along':
+ cost_matrix[gt_idx] = translation_errors[1]
+ elif match_metric == 'tr_perp':
+ cost_matrix[gt_idx] = translation_errors[2]
+ elif match_metric == 'tr_plane':
+ cost_matrix[gt_idx] = translation_errors[3]
+ elif match_metric == 'tr_line':
+ cost_matrix[gt_idx] = translation_errors[4]
+ elif match_metric == 'tr_mean':
+ cost_matrix[gt_idx] = (translation_errors[0] + translation_errors[1] + translation_errors[2] + translation_errors[3] + translation_errors[4]) / 5
+ elif match_metric == 'dir':
+ cost_matrix[gt_idx] = direction_errors
+ elif match_metric == 'tr_dist_dir':
+ cost_matrix[gt_idx] = (translation_errors[0] + direction_errors) / 2
+ elif match_metric == 'tr_line_dir':
+ cost_matrix[gt_idx] = (translation_errors[4] + direction_errors) / 2
+ elif match_metric == 'tr_mean_dir':
+ cost_matrix[gt_idx] = ((translation_errors[0] + translation_errors[1] + translation_errors[2] + translation_errors[3] + translation_errors[4]) / 5 + direction_errors) / 2
+ else:
+ raise NotImplementedError
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+ pred_joint_translations = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype)
+ pred_joint_rotations = np.ones_like(gt_joint_rotations, dtype=gt_joint_rotations.dtype)
+ if has_affordance:
+ pred_affordances = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype)
+ else:
+ pass
+ for gt_idx, proposal_idx in zip(row_ind, col_ind):
+ pred_joint_translations[gt_idx] = proposal_joint_translations[proposal_idx]
+ pred_joint_rotations[gt_idx] = proposal_joint_rotations[proposal_idx]
+ if has_affordance:
+ pred_affordances[gt_idx] = proposal_affordances[proposal_idx]
+ else:
+ pass
+ pred_joint_rotations = pred_joint_rotations / np.linalg.norm(pred_joint_rotations, axis=-1, keepdims=True)
+ if has_affordance:
+ return (pred_joint_translations, pred_joint_rotations, pred_affordances)
+ else:
+ return (pred_joint_translations, pred_joint_rotations)
+
+
+def pc_noise(pc:np.ndarray, distortion_rate:float, distortion_level:float,
+ outlier_rate:float, outlier_level:float) -> np.ndarray:
+ num_points = pc.shape[0]
+ pc_center = (np.max(pc, axis=0) + np.min(pc, axis=0)) / 2
+ pc_scale = np.max(pc, axis=0) - np.min(pc, axis=0)
+ pc_noised = pc.copy()
+ # distortion noise
+ distortion_indices = np.random.choice(num_points, int(distortion_rate * num_points), replace=False)
+ distortion_noise = np.random.normal(0.0, distortion_level * pc_scale, pc_noised[distortion_indices].shape)
+ pc_noised[distortion_indices] = pc_noised[distortion_indices] + distortion_noise
+ # outlier noise
+ outlier_indices = np.random.choice(num_points, int(outlier_rate * num_points), replace=False)
+ outlier_noise = np.random.uniform(pc_center - outlier_level * pc_scale, pc_center + outlier_level * pc_scale, pc_noised[outlier_indices].shape)
+ pc_noised[outlier_indices] = outlier_noise
+ # print(num_points, int(distortion_rate * num_points), int(outlier_rate * num_points))
+ return pc_noised
+
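+# usage sketch (editor's note; the rates and levels below are illustrative, not tuned values):
+#   pc_noisy = pc_noise(pc, distortion_rate=0.1, distortion_level=0.01,
+#                       outlier_rate=0.05, outlier_level=1.0)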
+
+if __name__ == '__main__':
+ # angle_tol = 1.5
+ angle_tol = 0.35
+ rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi))
+ print(rot_candidate_num)
+ sphere_pts = fibonacci_sphere(rot_candidate_num) # (N, 3)
+ sphere_pts = np.array(sphere_pts)
+
+    # measure the angle from each sphere point to its 8 nearest neighbors
+ min_angles = []
+ max_angles = []
+ mean_angles = []
+ for i in range(len(sphere_pts)):
+ a = sphere_pts[i] # (3,)
+ cos = np.dot(a, sphere_pts.T) # (N,)
+ cos = np.clip(cos, -1., 1.)
+ theta = np.arccos(cos) # (N,)
+ idxs = np.argsort(theta)[:9] # (9,)
+ thetas = theta[idxs[1:]] / np.pi * 180 # (8,)
+ min_angles.append(np.min(thetas))
+ max_angles.append(np.max(thetas))
+ mean_angles.append(np.mean(thetas))
+
+ print(np.min(min_angles))
+ print(np.max(min_angles))
+ print(np.mean(min_angles))
+ print(np.min(max_angles))
+ print(np.max(max_angles))
+ print(np.mean(max_angles))
+ print(np.min(mean_angles))
+ print(np.max(mean_angles))
+ print(np.mean(mean_angles))
diff --git a/utilities/env_utils.py b/utilities/env_utils.py
new file mode 100644
index 0000000..eb91e50
--- /dev/null
+++ b/utilities/env_utils.py
@@ -0,0 +1,13 @@
+import os
+import random
+import numpy as np
+import torch
+
+def setup_seed(seed:int) -> None:
+ random.seed(seed)
+ np.random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
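+    # note: fully reproducible runs usually also set torch.backends.cudnn.benchmark = False;
+    # it is left at its (False) default here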
diff --git a/utilities/metrics_utils.py b/utilities/metrics_utils.py
new file mode 100644
index 0000000..04a6742
--- /dev/null
+++ b/utilities/metrics_utils.py
@@ -0,0 +1,651 @@
+from typing import Tuple, List, Union, Dict, Optional
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+from logging import Logger
+from torch.utils.tensorboard import SummaryWriter
+import pandas as pd
+
+
+def calc_translation_error(pred_p:np.ndarray, gt_p:np.ndarray, pred_e:Optional[np.ndarray], gt_e:Optional[np.ndarray]) -> Tuple[float, float, float, float, float]:
+ def calc_plane_error(pred_translation:np.ndarray, pred_direction:np.ndarray,
+ gt_translation:np.ndarray, gt_direction:np.ndarray) -> float:
+ if abs(np.dot(pred_direction, gt_direction)) < 1e-3:
+ # parallel to the plane
+ # point-to-line distance
+ dist = np.linalg.norm(np.cross(pred_direction, gt_translation - pred_translation))
+ return dist
+ # gt_direction \dot (x - gt_translation) = 0
+ # x = pred_translation + t * pred_direction
+ t = np.dot(gt_translation - pred_translation, gt_direction) / np.dot(pred_direction, gt_direction)
+ x = pred_translation + t * pred_direction
+ dist = np.linalg.norm(x - gt_translation)
+ return dist
+ def calc_line_error(pred_translation:np.ndarray, pred_direction:np.ndarray,
+ gt_translation:np.ndarray, gt_direction:np.ndarray) -> float:
+ orth_vect = np.cross(gt_direction, pred_direction)
+ p = gt_translation - pred_translation
+ if np.linalg.norm(orth_vect) < 1e-3:
+ dist = np.linalg.norm(np.cross(p, gt_direction)) / np.linalg.norm(gt_direction)
+ else:
+ dist = abs(np.dot(orth_vect, p)) / np.linalg.norm(orth_vect)
+ return dist
+ distance_error = np.linalg.norm(pred_p - gt_p) * 100.0
+ if pred_e is None or gt_e is None:
+ along_error = 0.0
+ perp_error = 0.0
+ plane_error = 0.0
+ line_error = 0.0
+ else:
+ along_error = abs(np.dot(pred_p - gt_p, gt_e)) * 100.0
+ perp_error = np.sqrt(distance_error**2 - along_error**2)
+ plane_error = calc_plane_error(pred_p, pred_e, gt_p, gt_e) * 100.0
+ line_error = calc_line_error(pred_p, pred_e, gt_p, gt_e) * 100.0
+
+ return (distance_error, along_error, perp_error, plane_error, line_error)
+
+def calc_translation_error_batch(pred_ps:np.ndarray, gt_ps:np.ndarray, pred_es:Optional[np.ndarray], gt_es:Optional[np.ndarray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ def calc_plane_error_batch(pred_translations:np.ndarray, pred_directions:np.ndarray,
+ gt_translations:np.ndarray, gt_directions:np.ndarray) -> np.ndarray:
+ dists = np.zeros(pred_translations.shape[:-1], dtype=pred_translations.dtype)
+ flag = np.abs(np.sum(pred_directions * gt_directions, axis=-1)) < 1e-3
+ not_flag = np.logical_not(flag)
+
+        dists[flag] = np.linalg.norm(np.cross(pred_directions[flag], gt_translations[flag] - pred_translations[flag]), axis=-1)
+
+ ts = np.sum((gt_translations[not_flag] - pred_translations[not_flag]) * gt_directions[not_flag], axis=-1) / np.sum(pred_directions[not_flag] * gt_directions[not_flag], axis=-1)
+ xs = pred_translations[not_flag] + ts[..., None] * pred_directions[not_flag]
+ dists[not_flag] = np.linalg.norm(xs - gt_translations[not_flag], axis=-1)
+ return dists
+ def calc_line_error_batch(pred_translations:np.ndarray, pred_directions:np.ndarray,
+ gt_translations:np.ndarray, gt_directions:np.ndarray) -> np.ndarray:
+ dists = np.zeros(pred_translations.shape[:-1], dtype=pred_translations.dtype)
+ orth_vects = np.cross(gt_directions, pred_directions)
+ ps = gt_translations - pred_translations
+ flag = np.linalg.norm(orth_vects, axis=-1) < 1e-3
+ not_flag = np.logical_not(flag)
+
+ dists[flag] = np.linalg.norm(np.cross(ps[flag], gt_directions[flag]), axis=-1) / np.linalg.norm(gt_directions[flag], axis=-1)
+
+ dists[not_flag] = np.abs(np.sum(orth_vects[not_flag] * ps[not_flag], axis=-1)) / np.linalg.norm(orth_vects[not_flag], axis=-1)
+ return dists
+ distance_errors = np.linalg.norm(pred_ps - gt_ps, axis=-1) * 100.0
+ if pred_es is None or gt_es is None:
+ along_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+ perp_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+ plane_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+ line_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+ else:
+ along_errors = np.abs(np.sum((pred_ps - gt_ps) * gt_es, axis=-1)) * 100.0
+ perp_errors = np.sqrt(distance_errors**2 - along_errors**2)
+ plane_errors = calc_plane_error_batch(pred_ps, pred_es, gt_ps, gt_es) * 100.0
+ line_errors = calc_line_error_batch(pred_ps, pred_es, gt_ps, gt_es) * 100.0
+
+ return (distance_errors, along_errors, perp_errors, plane_errors, line_errors)
+
+def calc_direction_error(pred_e:np.ndarray, gt_e:np.ndarray) -> float:
+ cos_theta = np.dot(pred_e, gt_e)
+ cos_theta = np.clip(cos_theta, -1., 1.)
+    angle_degree = np.arccos(cos_theta) * 180.0 / np.pi
+    return angle_degree
+
+def calc_direction_error_batch(pred_es:np.ndarray, gt_es:np.ndarray) -> np.ndarray:
+ cos_thetas = np.sum(pred_es * gt_es, axis=-1)
+ cos_thetas = np.clip(cos_thetas, -1., 1.)
+    angle_degrees = np.arccos(cos_thetas) * 180.0 / np.pi
+    return angle_degrees
+
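+# note: the translation errors above are scaled by 100 (centimetres if the inputs are in
+# metres), and the direction errors are returned in degrees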
+
+def calculate_miou_loss(pred_seg:torch.Tensor, gt_seg_onehot:torch.Tensor) -> torch.Tensor:
+ # pred_seg: (B, N, C), gt_seg_onehot: (B, N, C)
+ dot = torch.sum(pred_seg * gt_seg_onehot, dim=-2)
+ denominator = torch.sum(pred_seg, dim=-2) + torch.sum(gt_seg_onehot, dim=-2) - dot
+ mIoU = dot / (denominator + 1e-7) # (B, C)
+ return torch.mean(1.0 - mIoU)
+
+def calculate_ncs_loss(pred_ncs:torch.Tensor, gt_ncs:torch.Tensor, gt_seg_onehot:torch.Tensor) -> torch.Tensor:
+ # pred_ncs: (B, N, 3 * C), gt_ncs: (B, N, 3), gt_seg_onehot: (B, N, C)
+ loss_ncs = 0
+ ncs_splits = torch.split(pred_ncs, split_size_or_sections=3, dim=-1) # (C,), (B, N, 3)
+ seg_splits = torch.split(gt_seg_onehot, split_size_or_sections=1, dim=2) # (C,), (B, N, 1)
+ C = len(ncs_splits)
+ for i in range(C):
+ diff_l2 = torch.norm(ncs_splits[i] - gt_ncs, dim=-1) # (B, N)
+ loss_ncs += torch.mean(seg_splits[i][:, :, 0] * diff_l2, dim=-1) # (B,)
+ return torch.mean(loss_ncs)
+
+def calculate_heatmap_loss(pred_heatmap:torch.Tensor, gt_heatmap:torch.Tensor, mask:torch.Tensor) -> torch.Tensor:
+ # pred_heatmap: (B, N), gt_heatmap: (B, N), mask: (B, N)
+ loss_heatmap = torch.abs(pred_heatmap - gt_heatmap)[mask > 0]
+ loss_heatmap = torch.mean(loss_heatmap)
+ loss_heatmap[torch.isnan(loss_heatmap)] = 0.0
+ return loss_heatmap
+
+def calculate_unitvec_loss(pred_unitvec:torch.Tensor, gt_unitvec:torch.Tensor, mask:torch.Tensor) -> torch.Tensor:
+ # pred_unitvec: (B, N, 3), gt_unitvec: (B, N, 3), mask: (B, N)
+ loss_unitvec = torch.norm(pred_unitvec - gt_unitvec, dim=-1)[mask > 0]
+ return torch.mean(loss_unitvec)
+ # pt_diff = pred_unitvec - gt_unitvec
+ # pt_dist = torch.sum(pt_diff.abs(), dim=-1) # (B, N)
+ # loss_pt_dist = torch.mean(pt_dist[mask > 0])
+ # pred_unitvec_normalized = pred_unitvec / (torch.norm(pred_unitvec, dim=-1, keepdim=True) + 1e-7)
+ # gt_unitvec_normalized = gt_unitvec / (torch.norm(gt_unitvec, dim=-1, keepdim=True) + 1e-7)
+ # dir_diff = torch.sum(-(pred_unitvec_normalized * gt_unitvec_normalized), dim=-1) # (B, N)
+ # loss_dir_diff = torch.mean(dir_diff[mask > 0])
+ # loss_unitvec = loss_pt_dist + loss_dir_diff
+ # return loss_unitvec
+
+
+def focal_loss(inputs:torch.Tensor, targets:torch.Tensor, alpha:Optional[torch.Tensor]=None, gamma:float=2.0, ignore_index:Optional[int]=None) -> torch.Tensor:
+ # inputs: (N, C), targets: (N,)
+ if ignore_index is not None:
+ valid_mask = targets != ignore_index
+ targets = targets[valid_mask]
+
+ if targets.shape[0] == 0:
+ return torch.tensor(0.0).to(dtype=inputs.dtype, device=inputs.device)
+
+ inputs = inputs[valid_mask]
+
+ log_p = torch.clamp(torch.log(inputs + 1e-7), max=0.0)
+ if ignore_index is not None:
+ ce_loss = F.nll_loss(
+ log_p, targets, weight=alpha, ignore_index=ignore_index, reduction="none"
+ )
+ else:
+ ce_loss = F.nll_loss(
+ log_p, targets, weight=alpha, reduction="none"
+ )
+ log_p_t = log_p.gather(1, targets[:, None]).squeeze(-1)
+ loss = ce_loss * ((1 - log_p_t.exp()) ** gamma)
+ loss = loss.mean()
+ return loss
+
+def dice_loss(input:torch.Tensor, target:torch.Tensor) -> torch.Tensor:
+ def one_hot(labels:torch.Tensor, num_classes:int, device:Optional[torch.device]=None, dtype:Optional[torch.dtype]=None) -> torch.Tensor:
+ if not isinstance(labels, torch.Tensor):
+ raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")
+
+ if not labels.dtype == torch.int64:
+ raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}")
+
+ if num_classes < 1:
+ raise ValueError("The number of classes must be bigger than one." " Got: {}".format(num_classes))
+
+ shape = labels.shape
+ one_hot = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
+ return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + 1e-6
+ # input: (N, C, H, W), target: (N, H, W)
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
+
+ if not len(input.shape) == 4:
+        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
+
+ if not input.shape[-2:] == target.shape[-2:]:
+ raise ValueError(f"input and target shapes must be the same. Got: {input.shape} and {target.shape}")
+
+ if not input.device == target.device:
+ raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")
+
+ # create the labels one hot tensor
+ target_one_hot = one_hot(target, num_classes=input.shape[1], device=input.device, dtype=input.dtype)
+
+ # compute the actual dice score
+ dims = (1, 2, 3)
+ intersection = torch.sum(input * target_one_hot, dims)
+ cardinality = torch.sum(input + target_one_hot, dims)
+
+ dice_score = 2.0 * intersection / (cardinality + 1e-8)
+
+ return torch.mean(-dice_score + 1.0)
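+# Usage sketch (assumed convention: `input` holds per-class probabilities over dim 1,
+# e.g. softmax outputs, and `target` holds int64 class indices; names are placeholders):
+#   probs = torch.softmax(logits, dim=1)   # (B, C, H, W)
+#   loss = dice_loss(probs, labels)        # labels: (B, H, W), dtype torch.int64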
+
+
+def calculate_offset_loss(pred_offsets:torch.Tensor, gt_offsets:torch.Tensor, function_masks:torch.Tensor) -> torch.Tensor:
+ # pred_offsets: (B, N, 3), gt_offsets: (B, N, 3), function_masks: (B, N)
+ pt_diff = pred_offsets - gt_offsets
+ pt_dist = torch.sum(pt_diff.abs(), dim=-1) # (B, N)
+ loss_pt_offset_dist = pt_dist[function_masks < 2].mean()
+ loss_pt_offset_dist[torch.isnan(loss_pt_offset_dist)] = 0.0
+ gt_offsets_norm = torch.norm(gt_offsets, dim=-1, keepdim=True)
+ gt_offsets_normalized = gt_offsets / (gt_offsets_norm + 1e-8)
+ pred_offsets_norm = torch.norm(pred_offsets, dim=-1, keepdim=True)
+ pred_offsets_normalized = pred_offsets / (pred_offsets_norm + 1e-8)
+ dir_diff = -(gt_offsets_normalized * pred_offsets_normalized).sum(dim=-1) # (B, N)
+ loss_offset_dir = dir_diff[function_masks < 2].mean()
+ loss_offset_dir[torch.isnan(loss_offset_dir)] = 0.0
+ loss_offset = loss_offset_dir + loss_pt_offset_dist
+ return loss_offset
+
+def calculate_dir_loss(pred_dirs:torch.Tensor, gt_dirs:torch.Tensor, function_masks:torch.Tensor) -> torch.Tensor:
+ # pred_dirs: (B, N, 3), gt_dirs: (B, N, 3), function_masks: (B, N)
+ pt_diff = pred_dirs - gt_dirs
+ pt_dist = torch.sum(pt_diff.abs(), dim=-1) # (B, N)
+ dis_loss = pt_dist[function_masks < 2].mean()
+ dis_loss[torch.isnan(dis_loss)] = 0.0
+ dir_loss = -(pred_dirs * gt_dirs).sum(dim=-1) # (B, N)
+ dir_loss = dir_loss[function_masks < 2].mean()
+ dir_loss[torch.isnan(dir_loss)] = 0.0
+ loss = dis_loss + dir_loss
+ return loss
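+# Both losses above combine an L1 distance term with a negative inner-product term (the
+# cosine similarity when the vectors are unit length), evaluated only on points with
+# function_masks < 2, i.e. points on movable (revolute or prismatic) parts under the
+# 0=revolute / 1=prismatic / 2=fixed convention used elsewhere in this repo.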
+
+
+def calc_pose_error(pose1:np.ndarray, pose2:np.ndarray) -> Tuple[float, float]:
+ error_matrix = np.dot(np.linalg.inv(pose1), pose2)
+ translation_error = error_matrix[:3, 3]
+ rotation_error = np.arccos(np.clip((np.trace(error_matrix[:3, :3]) - 1) / 2, -1, 1))
+ return (np.linalg.norm(translation_error), rotation_error)
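+# The rotation error is the geodesic angle of the relative rotation R = R1^T @ R2,
+# recovered from trace(R) = 1 + 2*cos(theta); the translation error is the norm of the
+# relative translation. For example, two poses differing only by a 90-degree rotation
+# about z give trace(R) = 1 and hence a rotation error of pi/2 radians.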
+
+
+def invaffordance_metrics(grasp_translation:np.ndarray, grasp_rotation:np.ndarray, grasp_score:float, affordable_position:np.ndarray,
+                          joint_base:np.ndarray, joint_direction:np.ndarray, joint_type:int) -> Tuple[float, ...]:
+ if joint_type == 0:
+ # revolute
+ l2_dist = np.linalg.norm(grasp_translation - affordable_position)
+ # normal = np.cross(joint_base - affordable_position, joint_base + joint_direction - affordable_position)
+ # normal = normal / np.linalg.norm(normal)
+ # plane_dist = abs(np.dot(grasp_translation - affordable_position, normal))
+ # return (1.0 - grasp_score,)
+ # return (l2_dist, plane_dist, 1.0 - grasp_score)
+ return (l2_dist,)
+ elif joint_type == 1:
+ # prismatic
+ l2_dist = np.linalg.norm(grasp_translation - affordable_position)
+ # plane_dist = abs(np.dot(grasp_translation - affordable_position, joint_direction))
+ # return (1.0 - grasp_score,)
+ # return (l2_dist, plane_dist, 1.0 - grasp_score)
+ return (l2_dist,)
+ else:
+ raise ValueError(f"Invalid joint_type: {joint_type}")
+
+def invaffordances2affordance(invaffordances:List[Tuple[float, ...]], sort:bool=False) -> List[float]:
+ if len(invaffordances) == 0:
+ return []
+ if len(invaffordances) == 1:
+ return [1.0]
+ # from list of tuple to dict of list, k: metrics index, v: invaffordance list
+ invaffordances_dict = {k: [v[k] for v in invaffordances] for k in range(len(invaffordances[0]))}
+ if sort:
+ # sort each list in dict
+ affordances_dict = {}
+ for k in invaffordances_dict.keys():
+ affordances_dict[k] = len(invaffordances_dict[k]) - 1 - np.argsort(np.argsort(invaffordances_dict[k]))
+ affordances_dict[k] /= len(invaffordances) - 1
+ else:
+ # normalize each list in dict
+ affordances_dict = {}
+ for k in invaffordances_dict.keys():
+ affordances_dict[k] = (np.max(invaffordances_dict[k]) - invaffordances_dict[k]) / (np.max(invaffordances_dict[k]) - np.min(invaffordances_dict[k]))
+ # merge into list
+ affordances = np.array([0.0 for _ in range(len(invaffordances))])
+ for k in affordances_dict.keys():
+ affordances += affordances_dict[k]
+ affordances /= len(affordances_dict.keys())
+ return affordances.tolist()
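+# Worked example (hypothetical values): for three grasps with inverse-affordance metrics
+# [(0.1,), (0.3,), (0.2,)], both sort=False (min-max normalized and flipped) and
+# sort=True (rank-based) yield affordances [1.0, 0.0, 0.5], so the grasp with the
+# smallest metric receives the highest affordance.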
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value
+ Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
+ """
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
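+# Usage sketch:
+#   meter = AverageMeter()
+#   meter.update(0.5, n=4)  # e.g. a batch loss of 0.5 averaged over 4 samples
+#   meter.update(0.7, n=2)
+#   meter.avg               # (0.5 * 4 + 0.7 * 2) / 6 ~= 0.567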
+
+
+def log_metrics(results_dict:Dict[str, Union[np.ndarray, int, List[str]]], logger:Logger, log_path:str,
+ tb_writer:Optional[SummaryWriter]=None, epoch:Optional[int]=None, split:Optional[str]=None) -> None:
+ translation_distance_errors = results_dict['translation_distance_errors']
+ translation_along_errors = results_dict['translation_along_errors']
+ translation_perp_errors = results_dict['translation_perp_errors']
+ translation_plane_errors = results_dict['translation_plane_errors']
+ translation_line_errors = results_dict['translation_line_errors']
+ translation_outliers_num = results_dict['translation_outliers_num']
+ rotation_errors = results_dict['rotation_errors']
+ rotation_outliers_num = results_dict['rotation_outliers_num']
+ if 'affordance_errors' in results_dict.keys():
+ affordance_errors = results_dict['affordance_errors']
+ affordance_outliers_num = results_dict['affordance_outliers_num']
+ else:
+ affordance_errors = [0, 0]
+ affordance_outliers_num = 0
+
+ if len(translation_distance_errors.shape) == 1:
+ data_num = translation_distance_errors.shape[0]
+ if 'names' in results_dict.keys():
+ names = results_dict['names']
+ dataframe = pd.DataFrame({'name': names,
+ 'translation_distance_errors': translation_distance_errors,
+ 'translation_along_errors': translation_along_errors,
+ 'translation_perp_errors': translation_perp_errors,
+ 'translation_plane_errors': translation_plane_errors,
+ 'translation_line_errors': translation_line_errors,
+ 'rotation_errors': rotation_errors,
+ 'affordance_errors': affordance_errors})
+ dataframe.to_csv(os.path.join(log_path, 'metrics.csv'), sep=',')
+ else:
+ pass
+
+ # mean
+ mean_translation_distance_error = np.mean(translation_distance_errors, axis=0)
+ mean_translation_along_error = np.mean(translation_along_errors, axis=0)
+ mean_translation_perp_error = np.mean(translation_perp_errors, axis=0)
+ mean_translation_plane_error = np.mean(translation_plane_errors, axis=0)
+ mean_translation_line_error = np.mean(translation_line_errors, axis=0)
+ mean_rotation_error = np.mean(rotation_errors, axis=0)
+ mean_affordance_error = np.mean(affordance_errors, axis=0)
+ logger.info(f"mean_translation_distance_error = {mean_translation_distance_error}")
+ logger.info(f"mean_translation_along_error = {mean_translation_along_error}")
+ logger.info(f"mean_translation_perp_error = {mean_translation_perp_error}")
+ logger.info(f"mean_translation_plane_error = {mean_translation_plane_error}")
+ logger.info(f"mean_translation_line_error = {mean_translation_line_error}")
+ logger.info(f"mean_rotation_error = {mean_rotation_error}")
+ logger.info(f"mean_affordance_error = {mean_affordance_error}")
+ # median
+ median_translation_distance_error = np.median(translation_distance_errors, axis=0)
+ median_translation_along_error = np.median(translation_along_errors, axis=0)
+ median_translation_perp_error = np.median(translation_perp_errors, axis=0)
+ median_translation_plane_error = np.median(translation_plane_errors, axis=0)
+ median_translation_line_error = np.median(translation_line_errors, axis=0)
+ median_rotation_error = np.median(rotation_errors, axis=0)
+ median_affordance_error = np.median(affordance_errors, axis=0)
+ logger.info(f"median_translation_distance_error = {median_translation_distance_error}")
+ logger.info(f"median_translation_along_error = {median_translation_along_error}")
+ logger.info(f"median_translation_perp_error = {median_translation_perp_error}")
+ logger.info(f"median_translation_plane_error = {median_translation_plane_error}")
+ logger.info(f"median_translation_line_error = {median_translation_line_error}")
+ logger.info(f"median_rotation_error = {median_rotation_error}")
+ logger.info(f"median_affordance_error = {median_affordance_error}")
+ # max
+ max_translation_distance_error = np.max(translation_distance_errors, axis=0)
+ max_translation_along_error = np.max(translation_along_errors, axis=0)
+ max_translation_perp_error = np.max(translation_perp_errors, axis=0)
+ max_translation_plane_error = np.max(translation_plane_errors, axis=0)
+ max_translation_line_error = np.max(translation_line_errors, axis=0)
+ max_rotation_error = np.max(rotation_errors, axis=0)
+ max_affordance_error = np.max(affordance_errors, axis=0)
+ logger.info(f"max_translation_distance_error = {max_translation_distance_error}")
+ logger.info(f"max_translation_along_error = {max_translation_along_error}")
+ logger.info(f"max_translation_perp_error = {max_translation_perp_error}")
+ logger.info(f"max_translation_plane_error = {max_translation_plane_error}")
+ logger.info(f"max_translation_line_error = {max_translation_line_error}")
+ logger.info(f"max_rotation_error = {max_rotation_error}")
+ logger.info(f"max_affordance_error = {max_affordance_error}")
+ # min
+ min_translation_distance_error = np.min(translation_distance_errors, axis=0)
+ min_translation_along_error = np.min(translation_along_errors, axis=0)
+ min_translation_perp_error = np.min(translation_perp_errors, axis=0)
+ min_translation_plane_error = np.min(translation_plane_errors, axis=0)
+ min_translation_line_error = np.min(translation_line_errors, axis=0)
+ min_rotation_error = np.min(rotation_errors, axis=0)
+ min_affordance_error = np.min(affordance_errors, axis=0)
+ logger.info(f"min_translation_distance_error = {min_translation_distance_error}")
+ logger.info(f"min_translation_along_error = {min_translation_along_error}")
+ logger.info(f"min_translation_perp_error = {min_translation_perp_error}")
+ logger.info(f"min_translation_plane_error = {min_translation_plane_error}")
+ logger.info(f"min_translation_line_error = {min_translation_line_error}")
+ logger.info(f"min_rotation_error = {min_rotation_error}")
+ logger.info(f"min_affordance_error = {min_affordance_error}")
+ # std
+ std_translation_distance_error = np.std(translation_distance_errors, axis=0)
+ std_translation_along_error = np.std(translation_along_errors, axis=0)
+ std_translation_perp_error = np.std(translation_perp_errors, axis=0)
+ std_translation_plane_error = np.std(translation_plane_errors, axis=0)
+ std_translation_line_error = np.std(translation_line_errors, axis=0)
+ std_rotation_error = np.std(rotation_errors, axis=0)
+ std_affordance_error = np.std(affordance_errors, axis=0)
+ logger.info(f"std_translation_distance_error = {std_translation_distance_error}")
+ logger.info(f"std_translation_along_error = {std_translation_along_error}")
+ logger.info(f"std_translation_perp_error = {std_translation_perp_error}")
+ logger.info(f"std_translation_plane_error = {std_translation_plane_error}")
+ logger.info(f"std_translation_line_error = {std_translation_line_error}")
+ logger.info(f"std_rotation_error = {std_rotation_error}")
+ logger.info(f"std_affordance_error = {std_affordance_error}")
+ # outliers
+ translation_outliers_ratio = translation_outliers_num / data_num
+ rotation_outliers_ratio = rotation_outliers_num / data_num
+ affordance_outliers_ratio = affordance_outliers_num / data_num
+ logger.info(f"translation_outliers_num = {translation_outliers_num}")
+ logger.info(f"rotation_outliers_num = {rotation_outliers_num}")
+ logger.info(f"affordance_outliers_num = {affordance_outliers_num}")
+ logger.info(f"translation_outliers_ratio = {translation_outliers_ratio}")
+ logger.info(f"rotation_outliers_ratio = {rotation_outliers_ratio}")
+ logger.info(f"affordance_outliers_ratio = {affordance_outliers_ratio}")
+
+ logger.info(f"data_num = {data_num}")
+
+ if tb_writer is not None:
+ tb_writer.add_scalars(f'{split}/translation_distance_error', {
+ 'mean': mean_translation_distance_error.item(),
+ 'median': median_translation_distance_error.item(),
+ 'max': max_translation_distance_error.item(),
+ 'min': min_translation_distance_error.item(),
+ 'std': std_translation_distance_error.item()
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/translation_along_error', {
+ 'mean': mean_translation_along_error.item(),
+ 'median': median_translation_along_error.item(),
+ 'max': max_translation_along_error.item(),
+ 'min': min_translation_along_error.item(),
+ 'std': std_translation_along_error.item()
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/translation_perp_error', {
+ 'mean': mean_translation_perp_error.item(),
+ 'median': median_translation_perp_error.item(),
+ 'max': max_translation_perp_error.item(),
+ 'min': min_translation_perp_error.item(),
+ 'std': std_translation_perp_error.item()
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/translation_plane_error', {
+ 'mean': mean_translation_plane_error.item(),
+ 'median': median_translation_plane_error.item(),
+ 'max': max_translation_plane_error.item(),
+ 'min': min_translation_plane_error.item(),
+ 'std': std_translation_plane_error.item()
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/translation_line_error', {
+ 'mean': mean_translation_line_error.item(),
+ 'median': median_translation_line_error.item(),
+ 'max': max_translation_line_error.item(),
+ 'min': min_translation_line_error.item(),
+ 'std': std_translation_line_error.item()
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/rotation_error', {
+ 'mean': mean_rotation_error.item(),
+ 'median': median_rotation_error.item(),
+ 'max': max_rotation_error.item(),
+ 'min': min_rotation_error.item(),
+ 'std': std_rotation_error.item()
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/affordance_error', {
+ 'mean': mean_affordance_error.item(),
+ 'median': median_affordance_error.item(),
+ 'max': max_affordance_error.item(),
+ 'min': min_affordance_error.item(),
+ 'std': std_affordance_error.item()
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/outliers_ratio', {
+ 'translation': translation_outliers_ratio,
+ 'rotation': rotation_outliers_ratio,
+ 'affordance': affordance_outliers_ratio
+ }, epoch)
+ else:
+ pass
+ elif len(translation_distance_errors.shape) == 2:
+ data_num = translation_distance_errors.shape[0]
+ joint_num = translation_distance_errors.shape[1]
+ if 'names' in results_dict.keys():
+ names = results_dict['names']
+ for j in range(joint_num):
+ dataframe = pd.DataFrame({'name': names,
+ 'translation_distance_errors': translation_distance_errors[:, j],
+ 'translation_along_errors': translation_along_errors[:, j],
+ 'translation_perp_errors': translation_perp_errors[:, j],
+ 'translation_plane_errors': translation_plane_errors[:, j],
+ 'translation_line_errors': translation_line_errors[:, j],
+ 'rotation_errors': rotation_errors[:, j],
+ 'affordance_errors': affordance_errors[:, j]})
+ dataframe.to_csv(os.path.join(log_path, f'metrics{j}.csv'), sep=',')
+ else:
+ pass
+
+ # mean
+ mean_translation_distance_error = np.mean(translation_distance_errors, axis=0)
+ mean_translation_along_error = np.mean(translation_along_errors, axis=0)
+ mean_translation_perp_error = np.mean(translation_perp_errors, axis=0)
+ mean_translation_plane_error = np.mean(translation_plane_errors, axis=0)
+ mean_translation_line_error = np.mean(translation_line_errors, axis=0)
+ mean_rotation_error = np.mean(rotation_errors, axis=0)
+ mean_affordance_error = np.mean(affordance_errors, axis=0)
+ logger.info(f"mean_translation_distance_error = {mean_translation_distance_error}")
+ logger.info(f"mean_translation_along_error = {mean_translation_along_error}")
+ logger.info(f"mean_translation_perp_error = {mean_translation_perp_error}")
+ logger.info(f"mean_translation_plane_error = {mean_translation_plane_error}")
+ logger.info(f"mean_translation_line_error = {mean_translation_line_error}")
+ logger.info(f"mean_rotation_error = {mean_rotation_error}")
+ logger.info(f"mean_affordance_error = {mean_affordance_error}")
+ # median
+ median_translation_distance_error = np.median(translation_distance_errors, axis=0)
+ median_translation_along_error = np.median(translation_along_errors, axis=0)
+ median_translation_perp_error = np.median(translation_perp_errors, axis=0)
+ median_translation_plane_error = np.median(translation_plane_errors, axis=0)
+ median_translation_line_error = np.median(translation_line_errors, axis=0)
+ median_rotation_error = np.median(rotation_errors, axis=0)
+ median_affordance_error = np.median(affordance_errors, axis=0)
+ logger.info(f"median_translation_distance_error = {median_translation_distance_error}")
+ logger.info(f"median_translation_along_error = {median_translation_along_error}")
+ logger.info(f"median_translation_perp_error = {median_translation_perp_error}")
+ logger.info(f"median_translation_plane_error = {median_translation_plane_error}")
+ logger.info(f"median_translation_line_error = {median_translation_line_error}")
+ logger.info(f"median_rotation_error = {median_rotation_error}")
+ logger.info(f"median_affordance_error = {median_affordance_error}")
+ # max
+ max_translation_distance_error = np.max(translation_distance_errors, axis=0)
+ max_translation_along_error = np.max(translation_along_errors, axis=0)
+ max_translation_perp_error = np.max(translation_perp_errors, axis=0)
+ max_translation_plane_error = np.max(translation_plane_errors, axis=0)
+ max_translation_line_error = np.max(translation_line_errors, axis=0)
+ max_rotation_error = np.max(rotation_errors, axis=0)
+ max_affordance_error = np.max(affordance_errors, axis=0)
+ logger.info(f"max_translation_distance_error = {max_translation_distance_error}")
+ logger.info(f"max_translation_along_error = {max_translation_along_error}")
+ logger.info(f"max_translation_perp_error = {max_translation_perp_error}")
+ logger.info(f"max_translation_plane_error = {max_translation_plane_error}")
+ logger.info(f"max_translation_line_error = {max_translation_line_error}")
+ logger.info(f"max_rotation_error = {max_rotation_error}")
+ logger.info(f"max_affordance_error = {max_affordance_error}")
+ # min
+ min_translation_distance_error = np.min(translation_distance_errors, axis=0)
+ min_translation_along_error = np.min(translation_along_errors, axis=0)
+ min_translation_perp_error = np.min(translation_perp_errors, axis=0)
+ min_translation_plane_error = np.min(translation_plane_errors, axis=0)
+ min_translation_line_error = np.min(translation_line_errors, axis=0)
+ min_rotation_error = np.min(rotation_errors, axis=0)
+ min_affordance_error = np.min(affordance_errors, axis=0)
+ logger.info(f"min_translation_distance_error = {min_translation_distance_error}")
+ logger.info(f"min_translation_along_error = {min_translation_along_error}")
+ logger.info(f"min_translation_perp_error = {min_translation_perp_error}")
+ logger.info(f"min_translation_plane_error = {min_translation_plane_error}")
+ logger.info(f"min_translation_line_error = {min_translation_line_error}")
+ logger.info(f"min_rotation_error = {min_rotation_error}")
+ logger.info(f"min_affordance_error = {min_affordance_error}")
+ # std
+ std_translation_distance_error = np.std(translation_distance_errors, axis=0)
+ std_translation_along_error = np.std(translation_along_errors, axis=0)
+ std_translation_perp_error = np.std(translation_perp_errors, axis=0)
+ std_translation_plane_error = np.std(translation_plane_errors, axis=0)
+ std_translation_line_error = np.std(translation_line_errors, axis=0)
+ std_rotation_error = np.std(rotation_errors, axis=0)
+ std_affordance_error = np.std(affordance_errors, axis=0)
+ logger.info(f"std_translation_distance_error = {std_translation_distance_error}")
+ logger.info(f"std_translation_along_error = {std_translation_along_error}")
+ logger.info(f"std_translation_perp_error = {std_translation_perp_error}")
+ logger.info(f"std_translation_plane_error = {std_translation_plane_error}")
+ logger.info(f"std_translation_line_error = {std_translation_line_error}")
+ logger.info(f"std_rotation_error = {std_rotation_error}")
+ logger.info(f"std_affordance_error = {std_affordance_error}")
+ # outliers
+ translation_outliers_ratio = translation_outliers_num / (data_num * joint_num)
+ rotation_outliers_ratio = rotation_outliers_num / (data_num * joint_num)
+ affordance_outliers_ratio = affordance_outliers_num / (data_num * joint_num)
+ logger.info(f"translation_outliers_num = {translation_outliers_num}")
+ logger.info(f"rotation_outliers_num = {rotation_outliers_num}")
+ logger.info(f"affordance_outliers_num = {affordance_outliers_num}")
+ logger.info(f"translation_outliers_ratio = {translation_outliers_ratio}")
+ logger.info(f"rotation_outliers_ratio = {rotation_outliers_ratio}")
+ logger.info(f"affordance_outliers_ratio = {affordance_outliers_ratio}")
+
+ logger.info(f"data_num = {data_num}, joint_num = {joint_num}")
+
+ if tb_writer is not None:
+ for j in range(joint_num):
+ tb_writer.add_scalars(f'{split}/joint_{j}/translation_distance_error', {
+ 'mean': mean_translation_distance_error[j],
+ 'median': median_translation_distance_error[j],
+ 'max': max_translation_distance_error[j],
+ 'min': min_translation_distance_error[j],
+ 'std': std_translation_distance_error[j]
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/joint_{j}/translation_along_error', {
+ 'mean': mean_translation_along_error[j],
+ 'median': median_translation_along_error[j],
+ 'max': max_translation_along_error[j],
+ 'min': min_translation_along_error[j],
+ 'std': std_translation_along_error[j]
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/joint_{j}/translation_perp_error', {
+ 'mean': mean_translation_perp_error[j],
+ 'median': median_translation_perp_error[j],
+ 'max': max_translation_perp_error[j],
+ 'min': min_translation_perp_error[j],
+ 'std': std_translation_perp_error[j]
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/joint_{j}/translation_plane_error', {
+ 'mean': mean_translation_plane_error[j],
+ 'median': median_translation_plane_error[j],
+ 'max': max_translation_plane_error[j],
+ 'min': min_translation_plane_error[j],
+ 'std': std_translation_plane_error[j]
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/joint_{j}/translation_line_error', {
+ 'mean': mean_translation_line_error[j],
+ 'median': median_translation_line_error[j],
+ 'max': max_translation_line_error[j],
+ 'min': min_translation_line_error[j],
+ 'std': std_translation_line_error[j]
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/joint_{j}/rotation_error', {
+ 'mean': mean_rotation_error[j],
+ 'median': median_rotation_error[j],
+ 'max': max_rotation_error[j],
+ 'min': min_rotation_error[j],
+ 'std': std_rotation_error[j]
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/joint_{j}/affordance_error', {
+ 'mean': mean_affordance_error[j],
+ 'median': median_affordance_error[j],
+ 'max': max_affordance_error[j],
+ 'min': min_affordance_error[j],
+ 'std': std_affordance_error[j]
+ }, epoch)
+ tb_writer.add_scalars(f'{split}/outliers_ratio', {
+ 'translation': translation_outliers_ratio,
+ 'rotation': rotation_outliers_ratio,
+ 'affordance': affordance_outliers_ratio
+ }, epoch)
+ else:
+ pass
+ else:
+ raise ValueError(f"Invalid shape of translation_distance_errors: {translation_distance_errors.shape}")
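+# Shape convention assumed by log_metrics: the error arrays are either 1-D (one joint per
+# sample, written to metrics.csv) or 2-D with shape (data_num, joint_num) (written per
+# joint to metrics{j}.csv and to per-joint TensorBoard scalars).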
diff --git a/utilities/model_utils.py b/utilities/model_utils.py
new file mode 100644
index 0000000..97dcd3f
--- /dev/null
+++ b/utilities/model_utils.py
@@ -0,0 +1,4 @@
+def inplace_relu(m):
+ classname = m.__class__.__name__
+ if classname.find('ReLU') != -1:
+ m.inplace=True
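+# Usage sketch: apply recursively to a model to convert activations to in-place mode and
+# save activation memory, e.g. `model.apply(inplace_relu)`; the class-name check also
+# matches variants such as nn.LeakyReLU.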
diff --git a/utilities/network_utils.py b/utilities/network_utils.py
new file mode 100644
index 0000000..6e7a175
--- /dev/null
+++ b/utilities/network_utils.py
@@ -0,0 +1,28 @@
+from typing import Optional
+from paramiko import SSHClient
+from scp import SCPClient
+
+
+def send(local_path:str, remote_path:str,
+ remote_ip:str, port:int=22, username:Optional[str]=None, password:Optional[str]=None, key_filename:Optional[str]=None) -> None:
+ ssh = SSHClient()
+ ssh.load_system_host_keys()
+ ssh.connect(remote_ip, port=port, username=username, password=password, key_filename=key_filename)
+
+ scp = SCPClient(ssh.get_transport())
+
+ scp.put(local_path, remote_path)
+
+    scp.close()
+    ssh.close()
+
+def read(local_path:str, remote_path:str,
+ remote_ip:str, port:int=22, username:Optional[str]=None, password:Optional[str]=None, key_filename:Optional[str]=None) -> None:
+ ssh = SSHClient()
+ ssh.load_system_host_keys()
+ ssh.connect(remote_ip, port=port, username=username, password=password, key_filename=key_filename)
+
+ scp = SCPClient(ssh.get_transport())
+
+ scp.get(remote_path, local_path)
+
+    scp.close()
+    ssh.close()
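+# Usage sketch (hypothetical host, credentials and paths):
+#   send('outputs/result.npz', '/remote/inbox/result.npz',
+#        remote_ip='192.168.1.10', username='robot', key_filename='/home/robot/.ssh/id_rsa')
+#   read('temp_data/grasp.npz', '/remote/outbox/grasp.npz',
+#        remote_ip='192.168.1.10', username='robot', key_filename='/home/robot/.ssh/id_rsa')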
diff --git a/utilities/vis_utils.py b/utilities/vis_utils.py
new file mode 100644
index 0000000..d51f08d
--- /dev/null
+++ b/utilities/vis_utils.py
@@ -0,0 +1,481 @@
+from typing import Tuple, Union, Optional
+import numpy as np
+import open3d as o3d
+import matplotlib.cm as cm
+
+from .constants import EPS
+
+
+def visualize(pc:np.ndarray, pc_color:Optional[np.ndarray]=None, pc_normal:Optional[np.ndarray]=None,
+ joint_translations:Optional[np.ndarray]=None, joint_rotations:Optional[np.ndarray]=None, affordable_positions:Optional[np.ndarray]=None,
+ joint_axis_colors:Optional[np.ndarray]=None, joint_point_colors:Optional[np.ndarray]=None, affordable_position_colors:Optional[np.ndarray]=None,
+ grasp_poses:Optional[np.ndarray]=None, grasp_widths:Optional[np.ndarray]=None, grasp_depths:Optional[np.ndarray]=None, grasp_affordances:Optional[np.ndarray]=None,
+ whether_frame:bool=True, whether_bbox:Union[bool, Tuple[np.ndarray, np.ndarray]]=True, window_name:str='visualization') -> None:
+ geometries = []
+
+ if whether_frame:
+ frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ geometries.append(frame)
+ else:
+ pass
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pc)
+ if pc_color is None:
+ pass
+ elif len(pc_color.shape) == 1:
+ pcd.paint_uniform_color(pc_color)
+ elif len(pc_color.shape) == 2:
+ pcd.colors = o3d.utility.Vector3dVector(pc_color)
+ else:
+ raise NotImplementedError
+ if pc_normal is None:
+ pass
+ else:
+ pcd.normals = o3d.utility.Vector3dVector(pc_normal)
+ geometries.append(pcd)
+
+ if isinstance(whether_bbox, bool) and whether_bbox:
+ bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points)
+ bbox.color = np.array([1, 0, 0])
+ geometries.append(bbox)
+ elif isinstance(whether_bbox, tuple):
+ bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound=whether_bbox[0], max_bound=whether_bbox[1])
+ bbox.color = np.array([1, 0, 0])
+ geometries.append(bbox)
+ else:
+ pass
+
+ joint_num = joint_translations.shape[0] if joint_translations is not None else 0
+ for j in range(joint_num):
+ joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1)
+ rotation = np.zeros((3, 3))
+ temp2 = np.cross(joint_rotations[j], np.array([1., 0., 0.]))
+ if np.linalg.norm(temp2) < EPS:
+ temp1 = np.cross(np.array([0., 1., 0.]), joint_rotations[j])
+ temp1 /= np.linalg.norm(temp1)
+ temp2 = np.cross(joint_rotations[j], temp1)
+ temp2 /= np.linalg.norm(temp2)
+ else:
+ temp2 /= np.linalg.norm(temp2)
+ temp1 = np.cross(temp2, joint_rotations[j])
+ temp1 /= np.linalg.norm(temp1)
+ rotation[:, 0] = temp1
+ rotation[:, 1] = temp2
+ rotation[:, 2] = joint_rotations[j]
+ joint_axis.rotate(rotation, np.array([[0], [0], [0]]))
+ joint_axis.translate(joint_translations[j].reshape((3, 1)))
+ if joint_axis_colors is None:
+ pass
+ elif len(joint_axis_colors.shape) == 1:
+ joint_axis.paint_uniform_color(joint_axis_colors)
+ elif len(joint_axis_colors.shape) == 2:
+ joint_axis.paint_uniform_color(joint_axis_colors[j])
+ else:
+ raise NotImplementedError
+ geometries.append(joint_axis)
+ joint_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.015)
+ joint_point = joint_point.translate(joint_translations[j].reshape((3, 1)))
+ if joint_point_colors is None:
+ pass
+ elif len(joint_point_colors.shape) == 1:
+ joint_point.paint_uniform_color(joint_point_colors)
+ elif len(joint_point_colors.shape) == 2:
+ joint_point.paint_uniform_color(joint_point_colors[j])
+ else:
+ raise NotImplementedError
+ geometries.append(joint_point)
+
+ if affordable_positions is not None:
+ affordable_position_num = affordable_positions.shape[0]
+ for i in range(affordable_position_num):
+ affordable_position = o3d.geometry.TriangleMesh.create_sphere(radius=0.015)
+ affordable_position = affordable_position.translate(affordable_positions[i].reshape((3, 1)))
+ if affordable_position_colors is None:
+ pass
+ elif len(affordable_position_colors.shape) == 1:
+ affordable_position.paint_uniform_color(affordable_position_colors)
+ elif len(affordable_position_colors.shape) == 2:
+ affordable_position.paint_uniform_color(affordable_position_colors[i])
+ else:
+ raise NotImplementedError
+ geometries.append(affordable_position)
+
+ if grasp_poses is not None:
+ grasp_num = grasp_poses.shape[0]
+ for g in range(grasp_num):
+ grasp_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
+ grasp_frame.transform(grasp_poses[g])
+ geometries.append(grasp_frame)
+
+ finger_width = 0.004
+ tail_length = 0.04
+ depth_base = 0.02
+ gg_width = grasp_widths[g]
+ gg_depth = grasp_depths[g]
+ gg_affordance = grasp_affordances[g]
+ gg_translation = grasp_poses[g][:3, 3]
+ gg_rotation = grasp_poses[g][:3, :3]
+
+ left = np.zeros((2, 3))
+ left[0] = np.array([-depth_base - finger_width, -gg_width / 2, 0])
+ left[1] = np.array([gg_depth, -gg_width / 2, 0])
+
+ right = np.zeros((2, 3))
+ right[0] = np.array([-depth_base - finger_width, gg_width / 2, 0])
+ right[1] = np.array([gg_depth, gg_width / 2, 0])
+
+ bottom = np.zeros((2, 3))
+ bottom[0] = np.array([-finger_width - depth_base, -gg_width / 2, 0])
+ bottom[1] = np.array([-finger_width - depth_base, gg_width / 2, 0])
+
+ tail = np.zeros((2, 3))
+ tail[0] = np.array([-(tail_length + finger_width + depth_base), 0, 0])
+ tail[1] = np.array([-(finger_width + depth_base), 0, 0])
+
+ vertices = np.vstack([left, right, bottom, tail])
+ vertices = np.dot(gg_rotation, vertices.T).T + gg_translation
+
+ line_set = o3d.geometry.LineSet()
+ line_set.points = o3d.utility.Vector3dVector(vertices)
+ line_set.lines = o3d.utility.Vector2iVector([[0, 1], [2, 3], [4, 5], [6, 7]])
+ if gg_affordance < 0.5:
+ line_set.paint_uniform_color([1, 2*gg_affordance, 0])
+ elif gg_affordance == 1.0:
+ line_set.paint_uniform_color([0., 0., 1.])
+ else:
+ line_set.paint_uniform_color([-2*gg_affordance+2, 1, 0])
+ geometries.append(line_set)
+
+ o3d.visualization.draw_geometries(geometries, point_show_normal=pc_normal is not None, window_name=window_name)
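+# Usage sketch (hypothetical arrays): draw a gray point cloud with a single joint whose
+# axis points along +z through the origin, rendered as a red arrow and sphere:
+#   visualize(pc, pc_color=np.array([0.5, 0.5, 0.5]),
+#             joint_translations=np.zeros((1, 3)), joint_rotations=np.array([[0., 0., 1.]]),
+#             joint_axis_colors=np.array([1., 0., 0.]), joint_point_colors=np.array([1., 0., 0.]))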
+
+def visualize_mask(pc:np.ndarray, instance_mask:np.ndarray, function_mask:np.ndarray,
+ pc_normal:Optional[np.ndarray]=None,
+ joint_translations:Optional[np.ndarray]=None, joint_rotations:Optional[np.ndarray]=None, affordable_positions:Optional[np.ndarray]=None,
+ whether_frame:bool=True, whether_bbox:Union[bool, Tuple[np.ndarray, np.ndarray]]=True, window_name:str='visualization') -> None:
+ geometries = []
+
+ if whether_frame:
+ frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ geometries.append(frame)
+ else:
+ pass
+
+ instance_ids = np.unique(instance_mask) # [0, J]
+ functions = []
+ for instance_id in instance_ids:
+ instance_function = function_mask[instance_mask == instance_id]
+ functions.append(np.unique(instance_function)[0])
+ revolute_num, prismatic_num, fixed_num = 0, 0, 0
+ for f in functions:
+ if f == 0:
+ revolute_num += 1
+ elif f == 1:
+ prismatic_num += 1
+ elif f == 2:
+ fixed_num += 1
+ else:
+ raise ValueError(f"Invalid function {f}")
+ # assert fixed_num == 1
+ revolute_gradient = 1.0 / revolute_num if revolute_num > 0 else 0.0
+ prismatic_gradient = 1.0 / prismatic_num if prismatic_num > 0 else 0.0
+
+ pc_color = np.zeros((pc.shape[0], 3), dtype=np.float32)
+ revolute_num, prismatic_num = 0, 0
+ for instance_idx, instance_id in enumerate(instance_ids):
+ if functions[instance_idx] == 0:
+ pc_color[instance_mask == instance_id] = np.array([1. - revolute_gradient * revolute_num, 0., 0.])
+ revolute_num += 1
+ elif functions[instance_idx] == 1:
+ pc_color[instance_mask == instance_id] = np.array([0., 1. - prismatic_gradient * prismatic_num, 0.])
+ prismatic_num += 1
+ elif functions[instance_idx] == 2:
+ pc_color[instance_mask == instance_id] = np.array([0., 0., 0.])
+ else:
+ raise ValueError(f"Invalid function {functions[instance_idx]}")
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pc)
+ pcd.colors = o3d.utility.Vector3dVector(pc_color)
+ if pc_normal is None:
+ pass
+ else:
+ pcd.normals = o3d.utility.Vector3dVector(pc_normal)
+ geometries.append(pcd)
+
+ if isinstance(whether_bbox, bool) and whether_bbox:
+ bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points)
+ bbox.color = np.array([1, 0, 0])
+ geometries.append(bbox)
+ elif isinstance(whether_bbox, tuple):
+ bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound=whether_bbox[0], max_bound=whether_bbox[1])
+ bbox.color = np.array([1, 0, 0])
+ geometries.append(bbox)
+ else:
+ pass
+
+ joint_num = joint_translations.shape[0] if joint_translations is not None else 0
+ revolute_num, prismatic_num = 0, 0
+ for j in range(joint_num):
+ joint_function = functions[j+1]
+ if joint_function == 0:
+ joint_color = np.array([1. - revolute_gradient * revolute_num, 0., 0.])
+ revolute_num += 1
+ elif joint_function == 1:
+ joint_color = np.array([0., 1. - prismatic_gradient * prismatic_num, 0.])
+ prismatic_num += 1
+ else:
+ raise ValueError(f"Invalid function {joint_function}")
+ joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1)
+ rotation = np.zeros((3, 3))
+ temp2 = np.cross(joint_rotations[j], np.array([1., 0., 0.]))
+ if np.linalg.norm(temp2) < EPS:
+ temp1 = np.cross(np.array([0., 1., 0.]), joint_rotations[j])
+ temp1 /= np.linalg.norm(temp1)
+ temp2 = np.cross(joint_rotations[j], temp1)
+ temp2 /= np.linalg.norm(temp2)
+ else:
+ temp2 /= np.linalg.norm(temp2)
+ temp1 = np.cross(temp2, joint_rotations[j])
+ temp1 /= np.linalg.norm(temp1)
+ rotation[:, 0] = temp1
+ rotation[:, 1] = temp2
+ rotation[:, 2] = joint_rotations[j]
+ joint_axis.rotate(rotation, np.array([[0], [0], [0]]))
+ joint_axis.translate(joint_translations[j].reshape((3, 1)))
+ joint_axis.paint_uniform_color(joint_color)
+ geometries.append(joint_axis)
+ joint_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.015)
+ joint_point = joint_point.translate(joint_translations[j].reshape((3, 1)))
+ joint_point.paint_uniform_color(joint_color)
+ geometries.append(joint_point)
+ if affordable_positions is not None:
+ affordance_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.015)
+ affordance_point = affordance_point.translate(affordable_positions[j].reshape((3, 1)))
+ affordance_point.paint_uniform_color(joint_color)
+ geometries.append(affordance_point)
+
+ o3d.visualization.draw_geometries(geometries, point_show_normal=pc_normal is not None, window_name=window_name)
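+# Color convention used by visualize_mask: revolute parts are drawn in shades of red,
+# prismatic parts in shades of green, and fixed parts in black; each joint's arrow,
+# sphere, and (optional) affordable position reuse the color of its part.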
+
+
+def visualize_translation_voting(grid_pc:np.ndarray, votes_list:np.ndarray,
+ pc:Optional[np.ndarray]=None, pc_color:Optional[np.ndarray]=None,
+ gt_translation:Optional[np.ndarray]=None, gt_color:Optional[np.ndarray]=None,
+ pred_translation:Optional[np.ndarray]=None, pred_color:Optional[np.ndarray]=None,
+ show_threshold:float=0.5, whether_frame:bool=True, whether_bbox:bool=True,
+ window_name:str='visualization') -> None:
+ geometries = []
+
+ if whether_frame:
+ frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ geometries.append(frame)
+ else:
+ pass
+
+ grid_pcd = o3d.geometry.PointCloud()
+ votes_list = votes_list / np.max(votes_list)
+    grid_pc_color = np.stack([np.ones_like(votes_list), 1-votes_list, 1-votes_list], axis=-1)
+ grid_pcd.points = o3d.utility.Vector3dVector(grid_pc[votes_list >= show_threshold])
+ grid_pcd.colors = o3d.utility.Vector3dVector(grid_pc_color[votes_list >= show_threshold])
+ geometries.append(grid_pcd)
+ print(grid_pc[votes_list >= show_threshold].shape[0])
+
+ if whether_bbox:
+ grid_bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(o3d.utility.Vector3dVector(grid_pc))
+ grid_bbox.color = np.array([0, 1, 0])
+ geometries.append(grid_bbox)
+ else:
+ pass
+
+ if gt_translation is None:
+ pass
+ else:
+ gt_point = o3d.geometry.PointCloud()
+ gt_point.points = o3d.utility.Vector3dVector(gt_translation[None, :])
+ if gt_color is None:
+ pass
+ else:
+ gt_point.paint_uniform_color(gt_color)
+ geometries.append(gt_point)
+
+ if pred_translation is None:
+ pass
+ else:
+ pred_point = o3d.geometry.PointCloud()
+ pred_point.points = o3d.utility.Vector3dVector(pred_translation[None, :])
+ if pred_color is None:
+ pass
+ else:
+ pred_point.paint_uniform_color(pred_color)
+ geometries.append(pred_point)
+
+ if pc is None:
+ pass
+ else:
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pc)
+ if pc_color is None:
+ pass
+ elif len(pc_color.shape) == 1:
+ pcd.paint_uniform_color(pc_color)
+ elif len(pc_color.shape) == 2:
+ pcd.colors = o3d.utility.Vector3dVector(pc_color)
+ else:
+ raise NotImplementedError
+ geometries.append(pcd)
+
+ if whether_bbox:
+ bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points)
+ bbox.color = np.array([1, 0, 0])
+ geometries.append(bbox)
+ else:
+ pass
+
+ o3d.visualization.draw_geometries(geometries, window_name=window_name)
+
+def visualize_rotation_voting(sphere_pts:np.ndarray, votes:np.ndarray,
+ pc:Optional[np.ndarray]=None, pc_color:Optional[np.ndarray]=None,
+ gt_rotation:Optional[np.ndarray]=None, gt_color:Optional[np.ndarray]=None,
+ pred_rotation:Optional[np.ndarray]=None, pred_color:Optional[np.ndarray]=None,
+ show_threshold:float=0.5, whether_frame:bool=True, whether_bbox:bool=True,
+ window_name:str='visualization') -> None:
+ geometries = []
+
+ if whether_frame:
+ frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ geometries.append(frame)
+ else:
+ pass
+
+ joint_num = sphere_pts.shape[0]
+ votes = votes / np.max(votes)
+ print(votes[votes >= show_threshold].shape[0])
+ for j in range(joint_num):
+ if votes[j] < show_threshold:
+ continue
+ joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1)
+ rotation = np.zeros((3, 3))
+ temp2 = np.cross(sphere_pts[j], np.array([1., 0., 0.]))
+ if np.linalg.norm(temp2) < EPS:
+ temp1 = np.cross(np.array([0., 1., 0.]), sphere_pts[j])
+ temp1 /= np.linalg.norm(temp1)
+ temp2 = np.cross(sphere_pts[j], temp1)
+ temp2 /= np.linalg.norm(temp2)
+ else:
+ temp2 /= np.linalg.norm(temp2)
+ temp1 = np.cross(temp2, sphere_pts[j])
+ temp1 /= np.linalg.norm(temp1)
+ rotation[:, 0] = temp1
+ rotation[:, 1] = temp2
+ rotation[:, 2] = sphere_pts[j]
+ joint_axis.rotate(rotation, np.array([[0], [0], [0]]))
+ joint_axis.paint_uniform_color(np.array([1, 1-votes[j], 1-votes[j]]))
+ geometries.append(joint_axis)
+
+ if gt_rotation is None:
+ pass
+ else:
+ joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1)
+ rotation = np.zeros((3, 3))
+ temp2 = np.cross(gt_rotation, np.array([1., 0., 0.]))
+ if np.linalg.norm(temp2) < EPS:
+ temp1 = np.cross(np.array([0., 1., 0.]), gt_rotation)
+ temp1 /= np.linalg.norm(temp1)
+ temp2 = np.cross(gt_rotation, temp1)
+ temp2 /= np.linalg.norm(temp2)
+ else:
+ temp2 /= np.linalg.norm(temp2)
+ temp1 = np.cross(temp2, gt_rotation)
+ temp1 /= np.linalg.norm(temp1)
+ rotation[:, 0] = temp1
+ rotation[:, 1] = temp2
+ rotation[:, 2] = gt_rotation
+ joint_axis.rotate(rotation, np.array([[0], [0], [0]]))
+ if gt_color is None:
+ pass
+ else:
+ joint_axis.paint_uniform_color(gt_color)
+ geometries.append(joint_axis)
+
+ if pred_rotation is None:
+ pass
+ else:
+ joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1)
+ rotation = np.zeros((3, 3))
+ temp2 = np.cross(pred_rotation, np.array([1., 0., 0.]))
+ if np.linalg.norm(temp2) < EPS:
+ temp1 = np.cross(np.array([0., 1., 0.]), pred_rotation)
+ temp1 /= np.linalg.norm(temp1)
+ temp2 = np.cross(pred_rotation, temp1)
+ temp2 /= np.linalg.norm(temp2)
+ else:
+ temp2 /= np.linalg.norm(temp2)
+ temp1 = np.cross(temp2, pred_rotation)
+ temp1 /= np.linalg.norm(temp1)
+ rotation[:, 0] = temp1
+ rotation[:, 1] = temp2
+ rotation[:, 2] = pred_rotation
+ joint_axis.rotate(rotation, np.array([[0], [0], [0]]))
+ if pred_color is None:
+ pass
+ else:
+ joint_axis.paint_uniform_color(pred_color)
+ geometries.append(joint_axis)
+
+ if pc is None:
+ pass
+ else:
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pc)
+ if pc_color is None:
+ pass
+ elif len(pc_color.shape) == 1:
+ pcd.paint_uniform_color(pc_color)
+ elif len(pc_color.shape) == 2:
+ pcd.colors = o3d.utility.Vector3dVector(pc_color)
+ else:
+ raise NotImplementedError
+ geometries.append(pcd)
+
+ if whether_bbox:
+ bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points)
+ bbox.color = np.array([1, 0, 0])
+ geometries.append(bbox)
+ else:
+ pass
+
+ o3d.visualization.draw_geometries(geometries, window_name=window_name)
+
+
+def visualize_confidence_voting(confs:np.ndarray, pc:np.ndarray, point_idxs:np.ndarray,
+ whether_frame:bool=True, whether_bbox:bool=True, window_name:str='visualization') -> None:
+ # confs: (N_t,), pc: (N, 3), point_idxs: (N_t, 2)
+ point_heats = np.zeros((pc.shape[0],)) # (N,)
+ for i in range(confs.shape[0]):
+ point_heats[point_idxs[i]] += confs[i]
+ print(np.max(point_heats), np.min(point_heats[point_heats > 0]), point_heats[point_heats > 0].shape[0])
+ point_heats /= np.max(point_heats)
+
+ geometries = []
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pc)
+ cmap = cm.get_cmap('jet')
+ pcd.colors = o3d.utility.Vector3dVector(cmap(point_heats)[:, :3])
+ geometries.append(pcd)
+
+ if whether_frame:
+ frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ geometries.append(frame)
+ else:
+ pass
+
+ if whether_bbox:
+ bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points)
+ bbox.color = np.array([1, 0, 0])
+ geometries.append(bbox)
+ else:
+ pass
+
+ o3d.visualization.draw_geometries(geometries, window_name=window_name)
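+# Each voting tuple contributes its confidence to the two points it was sampled from
+# (point_idxs is (N_t, 2)), so the rendered jet colormap shows how strongly every point
+# supported the final vote.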
diff --git a/weights/.gitkeep b/weights/.gitkeep
new file mode 100644
index 0000000..e69de29