diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b10163f --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +deps/* +!deps/.gitkeep +outputs/* +!outputs/.gitkeep +src_shot/build/* +**/__pycache__ +license/* +!license/.gitkeep +*.tar +*.so +.TimeRecord +imgui.ini +temp_data/* +!temp_data/.gitkeep diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..d8a1dd6 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 RPMArt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..aa78ded --- /dev/null +++ b/README.md @@ -0,0 +1,156 @@ +# RPMArt + + + Website + + + Paper + + + IROS + +
+ +teaser + +Official implementation for the paper [RPMArt: Towards Robust Perception and Manipulation for Articulated Objects](https://arxiv.org/abs/2403.16023), accepted by [IROS 2024](https://iros2024-abudhabi.org). + +For more information, please visit our [project website](https://r-pmart.github.io/). + +--- + +## 🛠 Installation +### 💻 Server-side +1. Clone this repo. + ```bash + git clone git@github.com:R-PMArt/rpmart.git + cd rpmart + ``` + +2. Create a [Conda](https://conda.org/) environment. + ```bash + conda create -n rpmart python=3.8 + conda activate rpmart + ``` + +3. Install [PyTorch](https://pytorch.org/). + ```bash + pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + ``` + +4. Install [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr). + ```bash + cd deps + git clone git@github.com:ildoonet/pytorch-gradual-warmup-lr.git + cd pytorch-gradual-warmup-lr + pip install . + cd ../.. + ``` + +5. Install [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine). + ```bash + conda install openblas-devel -c anaconda + export CUDA_HOME=/usr/local/cuda + pip install ninja + cd deps + git clone git@github.com:NVIDIA/MinkowskiEngine.git + cd MinkowskiEngine + pip install -U . --no-deps --install-option="--blas_include_dirs=${CONDA_PREFIX}/include" --install-option="--blas=openblas" + cd ../.. + ``` + +6. Install [CuPy](https://cupy.dev/). + ```bash + pip install cupy-cuda11x + ``` + +7. Install special [SAPIEN](https://sapien.ucsd.edu/). + ```bash + pip install http://download.cs.stanford.edu/orion/where2act/where2act_sapien_wheels/sapien-0.8.0.dev0-cp38-cp38-manylinux2014_x86_64.whl + ``` + +8. Install [AnyGrasp](https://github.com/graspnet/anygrasp_sdk). + ```bash + pip install cvxopt munch graspnetAPI + # follow AnyGrasp to use the licenses and weights and binary codes + ``` + +9. Install other dependencies. + ```bash + pip install -r requirements.txt + ``` + +10. Build `shot`. + ```bash + # you may need first install pybind11 and pcl + cd src_shot + mkdir build + cd build + cmake .. + make + cd ../.. + ``` + +### 🦾 Robot-side +1. Make sure `rt-linux` is enabled for Franka Emika Panda. + ```bash + uname -a + ``` + +2. Install `frankx` for robot and `pyrealsense2` for camera. + ```bash + pip install frankx pyrealsense2 + ``` + +3. Install `paramiko` for connecting with server. + ```bash + pip install paramiko + ``` + +## 🏃‍♂️ Run +1. Train or download [RoArtNet](https://huggingface.co/dadadadawjb/RoArtNet). + ```bash + bash scripts/train.sh + ``` + +2. Test RoArtNet. + ```bash + bash scripts/test.sh + ``` + +3. Evaluate RPMArt. + ```bash + bash scripts/eval_roartnet.sh + ``` + +4. Test RoArtNet on [RealArt-6](https://huggingface.co/datasets/dadadadawjb/RealArt-6). + ```bash + bash scripts/test_real.sh + ``` + +5. Evaluate RPMArt in the real world. + ```bash + # server side + bash scripts/real_service.sh + + # robot side + python real_eval.py + ``` + +## 🙏 Acknowledgement +* Our simulation environment is adapted from [VAT-Mart](https://github.com/warshallrho/VAT-Mart). +* Our voting module is adapted from [CPPF](https://github.com/qq456cvb/CPPF) and [CPPF++](https://github.com/qq456cvb/CPPF2). 
+ +## ✍ Citation +If you find our work useful, please consider citing: +```bibtex +@article{wang2024rpmart, + title={RPMArt: Towards Robust Perception and Manipulation for Articulated Objects}, + author={Wang, Junbo and Liu, Wenhai and Yu, Qiaojun and You, Yang and Liu, Liu and Wang, Weiming and Lu, Cewu}, + journal={arXiv preprint arXiv:2403.16023}, + year={2024} +} +``` + +## 📃 License +This repository is released under the [MIT](https://mit-license.org/) license. diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000..9ea5386 Binary files /dev/null and b/assets/teaser.png differ diff --git a/configs/.gitkeep b/configs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/configs/algorithm/formal.yaml b/configs/algorithm/formal.yaml new file mode 100644 index 0000000..5f6aaed --- /dev/null +++ b/configs/algorithm/formal.yaml @@ -0,0 +1,29 @@ +sampling: + sample_tuples_num: 100000 + tuple_more_num: 3 + +shot_encoder: + hidden_dims: [128, 128, 128, 128, 128] + feature_dim: 64 + bn: False + ln: False # bn and ln can only be set one + dropout: 0 + +encoder: + hidden_dims: [128, 128, 128, 128, 128] + bn: False + ln: False # bn and ln can only be set one + dropout: 0 + +voting: + rot_bin_num: 36 # 5 degree + voting_num: 120 # 3 degree + angle_tol: 1.5 # 10 degree + # angle_tol: 0.35 # 5 degree + translation2pc: False + rotation_cluster: False + multi_candidate: False + candidate_threshold: 0.5 + rotation_multi_neighbor: False + neighbor_threshold: 10 + bmm_size: 100000 diff --git a/configs/data/camera_config.json b/configs/data/camera_config.json new file mode 100644 index 0000000..a04d979 --- /dev/null +++ b/configs/data/camera_config.json @@ -0,0 +1,15 @@ +{ + "intrinsics": { + "fovx": [70, 70], + "fovy": [40, 60], + "height": 480, + "width": 640, + "near": 0.1, + "far": 100.0 + }, + "extrinsics": { + "dist": [0.6, 1.2], + "phi": [0, 60], + "theta": [120, 240] + } +} \ No newline at end of file diff --git a/configs/data/object_config.json b/configs/data/object_config.json new file mode 100644 index 0000000..f8c2608 --- /dev/null +++ b/configs/data/object_config.json @@ -0,0 +1,30 @@ +{ + "scale_min": 0.8, + "scale_max": 1.1, + "Microwave": { + "size": 0.4, + "joint_num": 1 + }, + "StorageFurniture": { + "size": 0.375, + "joint_num": 2, + "joint_types": ["prismatic", "revolute"] + }, + "Refrigerator": { + "size": 0.4, + "joint_num": 1 + }, + "Safe": { + "size": 0.4, + "joint_num": 1 + }, + "WashingMachine": { + "size": 0.15, + "joint_num": 1 + }, + "Drawer": { + "size": 0.4, + "joint_num": 3, + "joint_types": ["prismatic", "prismatic", "prismatic"] + } +} \ No newline at end of file diff --git a/configs/dataset/formal_drawer.yaml b/configs/dataset/formal_drawer.yaml new file mode 100644 index 0000000..47ea190 --- /dev/null +++ b/configs/dataset/formal_drawer.yaml @@ -0,0 +1,13 @@ +train_path: '/data2/junbo/sapien4/train' +test_path: '/data2/junbo/sapien4/test' +train_categories: ['Drawer'] +test_categories: ['Drawer'] +joint_num: 3 +resolution: 2.5e-2 +# resolution: 1e-2 +receptive_field: 10 +normalize: 'bound' +# normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/formal_microwave.yaml b/configs/dataset/formal_microwave.yaml new file mode 100644 index 0000000..183ad92 --- /dev/null +++ b/configs/dataset/formal_microwave.yaml @@ -0,0 +1,13 @@ +train_path: '/data2/junbo/sapien4/train' +test_path: '/data2/junbo/sapien4/test' +train_categories: ['Microwave'] +test_categories: ['Microwave'] +joint_num: 1 
+resolution: 2.5e-2 +# resolution: 1e-2 +receptive_field: 10 +normalize: 'bound' +# normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/formal_refrigerator.yaml b/configs/dataset/formal_refrigerator.yaml new file mode 100644 index 0000000..137d540 --- /dev/null +++ b/configs/dataset/formal_refrigerator.yaml @@ -0,0 +1,13 @@ +train_path: '/data2/junbo/sapien4/train' +test_path: '/data2/junbo/sapien4/test' +train_categories: ['Refrigerator'] +test_categories: ['Refrigerator'] +joint_num: 1 +resolution: 2.5e-2 +# resolution: 1e-2 +receptive_field: 10 +normalize: 'bound' +# normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/formal_safe.yaml b/configs/dataset/formal_safe.yaml new file mode 100644 index 0000000..48996ab --- /dev/null +++ b/configs/dataset/formal_safe.yaml @@ -0,0 +1,13 @@ +train_path: '/data2/junbo/sapien4/train' +test_path: '/data2/junbo/sapien4/test' +train_categories: ['Safe'] +test_categories: ['Safe'] +joint_num: 1 +resolution: 2.5e-2 +# resolution: 1e-2 +receptive_field: 10 +normalize: 'bound' +# normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/formal_storagefurniture.yaml b/configs/dataset/formal_storagefurniture.yaml new file mode 100644 index 0000000..679acb6 --- /dev/null +++ b/configs/dataset/formal_storagefurniture.yaml @@ -0,0 +1,13 @@ +train_path: '/data2/junbo/sapien4/train' +test_path: '/data2/junbo/sapien4/test' +train_categories: ['StorageFurniture'] +test_categories: ['StorageFurniture'] +joint_num: 2 +resolution: 2.5e-2 +# resolution: 1e-2 +receptive_field: 10 +normalize: 'bound' +# normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/formal_washingmachine.yaml b/configs/dataset/formal_washingmachine.yaml new file mode 100644 index 0000000..1c53fef --- /dev/null +++ b/configs/dataset/formal_washingmachine.yaml @@ -0,0 +1,13 @@ +train_path: '/data2/junbo/sapien4/train' +test_path: '/data2/junbo/sapien4/test' +train_categories: ['WashingMachine'] +test_categories: ['WashingMachine'] +joint_num: 1 +resolution: 2.5e-2 +# resolution: 1e-2 +receptive_field: 10 +normalize: 'bound' +# normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/real_with_with.yaml b/configs/dataset/real_with_with.yaml new file mode 100644 index 0000000..cf0e622 --- /dev/null +++ b/configs/dataset/real_with_with.yaml @@ -0,0 +1,11 @@ +path: '/data2/junbo/RealArt-6/with_table/microwave' +instances: ['0_with_chaos', '1_with_chaos', '2_with_chaos', '3_with_chaos', '4_with_chaos'] +joint_num: 1 +# resolution: 2.5e-2 +resolution: 1e-2 +receptive_field: 10 +# normalize: 'bound' +normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/real_with_without.yaml b/configs/dataset/real_with_without.yaml new file mode 100644 index 0000000..b0151c6 --- /dev/null +++ b/configs/dataset/real_with_without.yaml @@ -0,0 +1,11 @@ +path: '/data2/junbo/RealArt-6/with_table/microwave' +instances: ['0_without_chaos', '1_without_chaos', '2_without_chaos', '3_without_chaos', '4_without_chaos'] +joint_num: 1 +# resolution: 2.5e-2 +resolution: 1e-2 +receptive_field: 10 +# normalize: 'bound' +normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/real_without_with.yaml b/configs/dataset/real_without_with.yaml new file mode 100644 index 0000000..2602e7b --- /dev/null +++ 
b/configs/dataset/real_without_with.yaml @@ -0,0 +1,11 @@ +path: '/data2/junbo/RealArt-6/without_table/microwave' +instances: ['0_with_chaos', '1_with_chaos', '2_with_chaos', '3_with_chaos', '4_with_chaos'] +joint_num: 1 +# resolution: 2.5e-2 +resolution: 1e-2 +receptive_field: 10 +# normalize: 'bound' +normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/dataset/real_without_without.yaml b/configs/dataset/real_without_without.yaml new file mode 100644 index 0000000..1960c30 --- /dev/null +++ b/configs/dataset/real_without_without.yaml @@ -0,0 +1,11 @@ +path: '/data2/junbo/RealArt-6/without_table/microwave' +instances: ['0_without_chaos', '1_without_chaos', '2_without_chaos', '3_without_chaos', '4_without_chaos'] +joint_num: 1 +# resolution: 2.5e-2 +resolution: 1e-2 +receptive_field: 10 +# normalize: 'bound' +normalize: 'none' +sample_points_num: 1024 +rgb: False +denoise: False diff --git a/configs/eval_config.yaml b/configs/eval_config.yaml new file mode 100644 index 0000000..0a387c5 --- /dev/null +++ b/configs/eval_config.yaml @@ -0,0 +1,23 @@ +dataset: ??? +algorithm: + voting: + voting_num: 120 # 3 degree + # angle_tol: 1.5 # 10 degree + angle_tol: 0.35 # 5 degree + translation2pc: False + rotation_cluster: True + multi_candidate: True + candidate_threshold: 0.5 + rotation_multi_neighbor: True + neighbor_threshold: 5 + bmm_size: 100000 + +trained: + path: { + "Microwave": "./outputs/train/formal_microwave_2024_01_16_12_16_57", + "Refrigerator": "./outputs/train/formal_refrigerator_2024_01_17_22_52_26", + "Safe": "./outputs/train/formal_safe_2024_01_17_23_01_44", + "StorageFurniture": "./outputs/train/formal_storagefurniture_2024_01_18_11_13_43", + "Drawer": "./outputs/train/formal_drawer_2024_01_18_11_18_23", + "WashingMachine": "./outputs/train/formal_washingmachine_2024_01_17_23_07_14" + } diff --git a/configs/test_config.yaml b/configs/test_config.yaml new file mode 100644 index 0000000..e97e1df --- /dev/null +++ b/configs/test_config.yaml @@ -0,0 +1,34 @@ +dataset: + noise: True + distortion_rate: 0.1 + distortion_level: 0.01 + outlier_rate: 0.001 + outlier_level: 0.5 +algorithm: + voting: + voting_num: 120 # 3 degree + # angle_tol: 1.5 # 10 degree + angle_tol: 0.35 # 5 degree + translation2pc: False + rotation_cluster: True + multi_candidate: True + candidate_threshold: 0.5 + rotation_multi_neighbor: True + neighbor_threshold: 5 + bmm_size: 100000 +testing: + seed: 42 + device: 0 + batch_size: 16 + num_workers: 8 + training: False + +trained: + path: "./outputs/train/formal_microwave_2024_01_16_12_16_57" + +abbr: 'noise_microwave' +vis: False + +hydra: + run: + dir: "./outputs/test/${abbr}_${now:%Y_%m_%d_%H_%M_%S}" diff --git a/configs/test_gt_config.yaml b/configs/test_gt_config.yaml new file mode 100644 index 0000000..7b665c0 --- /dev/null +++ b/configs/test_gt_config.yaml @@ -0,0 +1,38 @@ +general: + seed: 42 + device: 0 + batch_size: 16 + num_workers: 8 + # abbr: 'none_norm_FTTTT' + abbr: 'formal_drawer' + +dataset: + path: "/data2/junbo/sapien4/test" + categories: ['Drawer'] + joint_num: 3 + # resolution: 2.5e-2 + resolution: 1e-2 + receptive_field: 10 + denoise: False + # normalize: 'median' + normalize: 'none' + sample_points_num: 1024 + +algorithm: + sample_tuples_num: 100000 + tuple_more_num: 0 + translation2pc: False + rotation_cluster: True + multi_candidate: True + candidate_threshold: 0.5 + rotation_multi_neighbor: True + neighbor_threshold: 5 + # angle_tol: 1.5 # 10 degree + angle_tol: 0.35 # 5 degree + voting_num: 120 # 3 
degree + bmm_size: 100000 + + +hydra: + run: + dir: "./outputs/test_gt/${general.abbr}_${now:%Y_%m_%d_%H_%M_%S}" diff --git a/configs/test_real_config.yaml b/configs/test_real_config.yaml new file mode 100644 index 0000000..3512036 --- /dev/null +++ b/configs/test_real_config.yaml @@ -0,0 +1,31 @@ +defaults: + - dataset: real_without_without + - _self_ +algorithm: + voting: + voting_num: 120 # 3 degree + # angle_tol: 1.5 # 10 degree + angle_tol: 0.35 # 5 degree + translation2pc: False + rotation_cluster: True + multi_candidate: True + candidate_threshold: 0.5 + rotation_multi_neighbor: True + neighbor_threshold: 5 + bmm_size: 100000 +testing: + seed: 42 + device: 0 + batch_size: 16 + num_workers: 8 + training: False + +trained: + path: "./outputs/train/formal_microwave_2024_01_16_12_16_57" + +abbr: 'formal_microwave_without_without' +vis: False + +hydra: + run: + dir: "./outputs/test_real/${abbr}_${now:%Y_%m_%d_%H_%M_%S}" diff --git a/configs/train_config.yaml b/configs/train_config.yaml new file mode 100644 index 0000000..f0c2e76 --- /dev/null +++ b/configs/train_config.yaml @@ -0,0 +1,12 @@ +defaults: + - dataset: formal_microwave + - algorithm: formal + - training: formal + - _self_ + +abbr: 'formal_microwave' +debug: False + +hydra: + run: + dir: "./outputs/train/${abbr}_${now:%Y_%m_%d_%H_%M_%S}" diff --git a/configs/training/formal.yaml b/configs/training/formal.yaml new file mode 100644 index 0000000..d0670e4 --- /dev/null +++ b/configs/training/formal.yaml @@ -0,0 +1,21 @@ +seed: 42 +device: 0 +batch_size: 16 +num_workers: 8 +epoch_num: 200 # microwave +# epoch_num: 200 # storagefurniture +# epoch_num: 300 # refrigerator +# epoch_num: 120 # safe +# epoch_num: 200 # washingmachine +# epoch_num: 100 # drawer +lr: 1e-3 +weight_decay: 0 +lambda_rot: 0.1 +lambda_afford: 1.0 +lambda_conf: 0.5 +val_training: True +val_training_num: 64 +val_testing: True +val_testing_num: 64 +test_train: True +test_test: True diff --git a/datasets/point_tuple_dataset.py b/datasets/point_tuple_dataset.py new file mode 100644 index 0000000..8f29dfb --- /dev/null +++ b/datasets/point_tuple_dataset.py @@ -0,0 +1,257 @@ +from typing import List +import os +import json +import time +import itertools +from pathlib import Path +import tqdm +import random +import numpy as np +import torch +import MinkowskiEngine as ME +import open3d as o3d + +if __name__ == '__main__': + import sys + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utilities.data_utils import pc_normalize, joints_normalize, farthest_point_sample, generate_target_tr, generate_target_rot, transform_pc, transform_dir +from utilities.vis_utils import visualize +from utilities.env_utils import setup_seed +from utilities.constants import light_blue_color, red_color, dark_red_color, dark_green_color +from src_shot.build import shot + + +class ArticulationDataset(torch.utils.data.Dataset): + def __init__(self, path:str, instances:List[str], joint_num:int, resolution:float, receptive_field:int, + sample_points_num:int, sample_tuples_num:int, tuple_more_num:int, + rgb:bool, denoise:bool, normalize:str, debug:bool, vis:bool, is_train:bool) -> None: + super().__init__() + self.path = path + self.instances = instances + self.joint_num = joint_num + self.resolution = resolution + self.receptive_field = receptive_field + self.sample_points_num = sample_points_num + self.sample_tuples_num = sample_tuples_num + self.tuple_more_num = tuple_more_num + self.rgb = rgb + self.denoise = denoise + self.normalize = normalize + self.debug = debug + 
self.vis = vis + self.is_train = is_train + self.fns = sorted(list(itertools.chain(*[list(Path(path).glob('{}/*.npz'.format(instance))) for instance in instances]))) + self.permutations = list(itertools.permutations(range(self.sample_points_num), 2)) + if debug: + print(f"{len(self.fns) =}", f"{self.fns[0] =}") + + def __len__(self): + return len(self.fns) + + def __getitem__(self, idx:int): + # load data + data = np.load(self.fns[idx]) + + pc = data['point_cloud'].astype(np.float32) # (N'', 3) + if self.debug: + print(f"{pc.shape = }") + print(f"{idx = }", f"{self.fns[idx] = }") + + if self.rgb: + pc_color = data['rgb'].astype(np.float32) # (N'', 3) + + assert data['joints'].shape[0] == self.joint_num + joints = data['joints'].astype(np.float32) # (J, 9) + if self.debug: + print(f"{joints.shape = }") + + joint_translations = joints[:, 0:3].astype(np.float32) # (J, 3) + joint_rotations = joints[:, 3:6].astype(np.float32) # (J, 3) + affordable_positions = joints[:, 6:9].astype(np.float32) # (J, 3) + joint_types = joints[:, -1].astype(np.int64) # (J,) + if self.debug: + print(f"{joint_translations.shape = }", f"{joint_rotations.shape = }", f"{affordable_positions.shape = }") + + # TODO: transform to sapien coordinate, not compatible with pybullet generated data + transform_matrix = np.array([[0, 0, 1, 0], + [-1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, 0, 1]]) + pc = transform_pc(pc, transform_matrix) + joint_translations = transform_pc(joint_translations, transform_matrix) + joint_rotations = transform_dir(joint_rotations, transform_matrix) + affordable_positions = transform_pc(affordable_positions, transform_matrix) + + if self.debug and self.vis: + # camera coordinate as (x, y, z) = (right, down, in) + # normal as inside/outside both + visualize(pc, pc_color=pc_color if self.rgb else light_blue_color, pc_normal=None, + joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions, + joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color, + whether_frame=True, whether_bbox=True, window_name='before') + + # preprocess + suffix = 'cache' + if self.denoise: + suffix += '_denoise' + suffix += f'_{self.normalize}' + suffix += f'_{str(self.resolution)}' + suffix += f'_{str(self.receptive_field)}' + suffix += f'_{str(self.sample_points_num)}' + preprocessed_path = os.path.join(os.path.dirname(str(self.fns[idx])), suffix, os.path.basename(str(self.fns[idx]))) + if os.path.exists(preprocessed_path): + preprocessed_data = np.load(preprocessed_path) + pc = preprocessed_data['pc'].astype(np.float32) + pc_normal = preprocessed_data['pc_normal'].astype(np.float32) + pc_shot = preprocessed_data['pc_shot'].astype(np.float32) + if self.rgb: + pc_color = preprocessed_data['pc_color'].astype(np.float32) + center = preprocessed_data['center'].astype(np.float32) + scale = float(preprocessed_data['scale']) + joint_translations, joint_rotations = joints_normalize(joint_translations, joint_rotations, center, scale) + affordable_positions, _ = joints_normalize(affordable_positions, None, center, scale) + else: + start_time = time.time() + if self.denoise: + # pcd = o3d.geometry.PointCloud() + # pcd.points = o3d.utility.Vector3dVector(pc) + # _, index = pcd.remove_statistical_outlier(nb_neighbors=100, std_ratio=1.5) + # pc = pc[index] + # if self.rgb: + # pc_color = pc_color[index] + valid_mask = pc[:, 0] > 0.05 + pc = pc[valid_mask] + if self.rgb: + pc_color = pc_color[valid_mask] + end_time = time.time() + if 
self.debug: + print(f"denoise: {end_time - start_time}") + print(f"{pc.shape = }") + + start_time = time.time() + pc, center, scale = pc_normalize(pc, self.normalize) + joint_translations, joint_rotations = joints_normalize(joint_translations, joint_rotations, center, scale) + affordable_positions, _ = joints_normalize(affordable_positions, None, center, scale) + end_time = time.time() + if self.debug: + print(f"pc_normalize: {end_time - start_time}") + + start_time = time.time() + indices = ME.utils.sparse_quantize(np.ascontiguousarray(pc), return_index=True, quantization_size=self.resolution)[1] + pc = np.ascontiguousarray(pc[indices].astype(np.float32)) # (N', 3) + if self.rgb: + pc_color = pc_color[indices] # (N', 3) + end_time = time.time() + if self.debug: + print(f"sparse_quantize: {end_time - start_time}") + print(f"{pc.shape = }") + + start_time = time.time() + pc_normal = shot.estimate_normal(pc, self.resolution * self.receptive_field).reshape(-1, 3).astype(np.float32) + pc_normal[~np.isfinite(pc_normal)] = 0 # (N', 3) + end_time = time.time() + if self.debug: + print(f"estimate_normal: {end_time - start_time}") + print(f"{pc_normal.shape = }") + + start_time = time.time() + pc_shot = shot.compute(pc, self.resolution * self.receptive_field, self.resolution * self.receptive_field).reshape(-1, 352).astype(np.float32) + pc_shot[~np.isfinite(pc_shot)] = 0 # (N', 352) + end_time = time.time() + if self.debug: + print(f"shot: {end_time - start_time}") + print(f"{pc_shot.shape = }") + + start_time = time.time() + pc, indices = farthest_point_sample(pc, self.sample_points_num) # (N, 3) + pc_normal = pc_normal[indices] # (N, 3) + pc_shot = pc_shot[indices] # (N, 352) + if self.rgb: + pc_color = pc_color[indices] # (N, 3) + end_time = time.time() + if self.debug: + print(f"farthest_point_sample: {end_time - start_time}") + print(f"{pc.shape = }", f"{pc_normal.shape = }", f"{pc_shot.shape = }") + + if not self.debug: + os.makedirs(os.path.dirname(preprocessed_path), exist_ok=True) + if self.rgb: + np.savez(preprocessed_path, pc=pc, pc_normal=pc_normal, pc_shot=pc_shot, pc_color=pc_color, center=center, scale=scale) + else: + np.savez(preprocessed_path, pc=pc, pc_normal=pc_normal, pc_shot=pc_shot, center=center, scale=scale) + + if self.debug and self.vis: + # camera coordinate as (x, y, z) = (right, down, in) + # normal as inside/outside both + visualize(pc, pc_color=pc_color if self.rgb else light_blue_color, pc_normal=pc_normal, + joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions, + joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color, + whether_frame=True, whether_bbox=True, window_name='after') + + # sample point tuples + start_time = time.time() + point_idxs = random.sample(self.permutations, self.sample_tuples_num) + point_idxs = np.array(point_idxs, dtype=np.int64) # (N_t, 2) + point_idxs_more = np.random.randint(0, self.sample_points_num, size=(self.sample_tuples_num, self.tuple_more_num), dtype=np.int64) # (N_t, N_m) + point_idxs_all = np.concatenate([point_idxs, point_idxs_more], axis=-1) # (N_t, 2 + N_m) + end_time = time.time() + if self.debug: + print(f"sample_point_tuples: {end_time - start_time}") + print(f"{point_idxs_all.shape = }") + + # generate targets + start_time = time.time() + targets_tr, targets_rot = [], [] + for j in range(self.joint_num): + target_tr = generate_target_tr(pc, joint_translations[j], point_idxs_all[:, :2]) # (N_t, 2) + target_rot = 
generate_target_rot(pc, joint_rotations[j], point_idxs_all[:, :2]) # (N_t,) + targets_tr.append(target_tr) + targets_rot.append(target_rot) + targets_tr = np.stack(targets_tr, axis=0).astype(np.float32) # (J, N_t, 2) + targets_rot = np.stack(targets_rot, axis=0).astype(np.float32) # (J, N_t) + end_time = time.time() + if self.debug: + print(f"generate_targets: {end_time - start_time}") + print(f"{targets_tr.shape = }", f"{targets_rot.shape = }") + + if self.is_train: + if self.rgb: + return pc, pc_normal, pc_shot, pc_color, \ + targets_tr, targets_rot, \ + point_idxs_all + else: + return pc, pc_normal, pc_shot, \ + targets_tr, targets_rot, \ + point_idxs_all + else: + if self.rgb: + return pc, pc_normal, pc_shot, pc_color, \ + joint_translations, joint_rotations, affordable_positions, joint_types, \ + point_idxs_all, str(self.fns[idx]) + else: + return pc, pc_normal, pc_shot, \ + joint_translations, joint_rotations, affordable_positions, joint_types, \ + point_idxs_all, str(self.fns[idx]) + + +if __name__ == '__main__': + setup_seed(seed=42) + path = "/data2/junbo/RealArt-6/without_table/microwave" + instances = ['0_without_chaos', '1_without_chaos', '2_without_chaos', '3_without_chaos', '4_without_chaos'] + joint_num = 1 + resolution = 1e-2 + receptive_field = 10 + sample_points_num = 1024 + sample_tuples_num = 100000 + tuple_more_num = 3 + normalize = 'none' + rgb = False + denoise = True + dataset = ArticulationDataset(path, instances, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + rgb, denoise, normalize, debug=True, vis=True, is_train=False) + batch_size = 2 + num_workers = 0 + dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers) + for results in tqdm.tqdm(dataloader): + import pdb; pdb.set_trace() diff --git a/datasets/rconfmask_afford_point_tuple_dataset.py b/datasets/rconfmask_afford_point_tuple_dataset.py new file mode 100644 index 0000000..1c00c6d --- /dev/null +++ b/datasets/rconfmask_afford_point_tuple_dataset.py @@ -0,0 +1,260 @@ +from typing import List +import os +import time +import itertools +from pathlib import Path +import tqdm +import random +import numpy as np +import torch +import MinkowskiEngine as ME +import open3d as o3d + +if __name__ == '__main__': + import sys + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utilities.data_utils import transform_pc, transform_dir, pc_normalize, joints_normalize, farthest_point_sample, generate_target_tr, generate_target_rot, pc_noise +from utilities.vis_utils import visualize_mask, visualize_confidence_voting +from utilities.env_utils import setup_seed +from src_shot.build import shot + + +class ArticulationDataset(torch.utils.data.Dataset): + def __init__(self, path:str, categories:List[str], joint_num:int, resolution:float, receptive_field:int, + sample_points_num:int, sample_tuples_num:int, tuple_more_num:int, + noise:bool, distortion_rate:float, distortion_level:float, outlier_rate:float, outlier_level:float, + rgb:bool, denoise:bool, normalize:str, debug:bool, vis:bool, is_train:bool) -> None: + super().__init__() + self.path = path + self.categories = categories + self.joint_num = joint_num + self.resolution = resolution + self.receptive_field = receptive_field + self.sample_points_num = sample_points_num + self.sample_tuples_num = sample_tuples_num + self.tuple_more_num = tuple_more_num + self.noise = noise # NOTE: only used in testing + self.distortion_rate = distortion_rate 
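+        # noise is only applied at test time (pc_noise is called in __getitem__ when self.noise is True);
+        # the *_rate / *_level pairs presumably set the fraction of points corrupted and the corruption magnitude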
+ self.distortion_level = distortion_level + self.outlier_rate = outlier_rate + self.outlier_level = outlier_level + self.rgb = rgb + self.denoise = denoise + self.normalize = normalize + self.debug = debug + self.vis = vis + self.is_train = is_train + self.fns = sorted(list(itertools.chain(*[list(Path(path).glob('{}*/*.npz'.format(category))) for category in categories]))) + self.permutations = list(itertools.permutations(range(self.sample_points_num), 2)) + if debug: + print(f"{len(self.fns) =}", f"{self.fns[0] =}") + + def __len__(self): + return len(self.fns) + + def __getitem__(self, idx:int): + # load data + data = np.load(self.fns[idx]) + + c2w = data['extrinsic'].astype(np.float32) # (4, 4) + w2c = np.linalg.inv(c2w) + pc = data['pcd_world'].astype(np.float32) # (N'', 3) + pc = transform_pc(pc, w2c) + if self.noise: + pc = pc_noise(pc, self.distortion_rate, self.distortion_level, self.outlier_rate, self.outlier_level) + if self.debug: + print(f"{pc.shape = }") + print(f"{idx = }", f"{self.fns[idx] = }") + + if self.rgb: + pc_color = data['pcd_color'].astype(np.float32) # (N'', 3) + + instance_mask = data['instance_mask'].astype(np.int64) # (N'',) + function_mask = data['function_mask'].astype(np.int64) # (N'',) + + joint_translations = data['joint_bases'].astype(np.float32) # (J, 3) + joint_translations = transform_pc(joint_translations, w2c) + joint_rotations = data['joint_directions'].astype(np.float32) # (J, 3) + joint_rotations = transform_dir(joint_rotations, w2c) + affordable_positions = data['affordable_positions'].astype(np.float32) # (J, 3) + affordable_positions = transform_pc(affordable_positions, w2c) + joint_num = joint_translations.shape[0] + assert self.joint_num == joint_num + if self.debug: + print(f"{joint_translations.shape = }", f"{joint_rotations.shape = }", f"{affordable_positions.shape = }") + + if self.debug and self.vis: + # camera coordinate as (x, y, z) = (in, left, up) + # normal as inside/outside both + visualize_mask(pc, instance_mask, function_mask, pc_normal=None, + joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions, + whether_frame=True, whether_bbox=True, window_name='before') + + # preprocess + start_time = time.time() + if self.denoise: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc) + _, index = pcd.remove_statistical_outlier(nb_neighbors=100, std_ratio=1.5) + pc = pc[index] + if self.rgb: + pc_color = pc_color[index] + end_time = time.time() + if self.debug: + print(f"denoise: {end_time - start_time}") + print(f"{pc.shape = }") + + start_time = time.time() + pc, center, scale = pc_normalize(pc, self.normalize) + joint_translations, joint_rotations = joints_normalize(joint_translations, joint_rotations, center, scale) + affordable_positions, _ = joints_normalize(affordable_positions, None, center, scale) + end_time = time.time() + if self.debug: + print(f"pc_normalize: {end_time - start_time}") + + start_time = time.time() + indices = ME.utils.sparse_quantize(np.ascontiguousarray(pc), return_index=True, quantization_size=self.resolution)[1] + pc = np.ascontiguousarray(pc[indices].astype(np.float32)) # (N', 3) + instance_mask = instance_mask[indices] # (N',) + function_mask = function_mask[indices] # (N',) + if self.rgb: + pc_color = pc_color[indices] # (N', 3) + end_time = time.time() + if self.debug: + print(f"sparse_quantize: {end_time - start_time}") + print(f"{pc.shape = }") + + start_time = time.time() + pc_normal = shot.estimate_normal(pc, 
self.resolution * self.receptive_field).reshape(-1, 3).astype(np.float32) + pc_normal[~np.isfinite(pc_normal)] = 0 # (N', 3) + end_time = time.time() + if self.debug: + print(f"estimate_normal: {end_time - start_time}") + print(f"{pc_normal.shape = }") + + start_time = time.time() + pc_shot = shot.compute(pc, self.resolution * self.receptive_field, self.resolution * self.receptive_field).reshape(-1, 352).astype(np.float32) + pc_shot[~np.isfinite(pc_shot)] = 0 # (N', 352) + end_time = time.time() + if self.debug: + print(f"shot: {end_time - start_time}") + print(f"{pc_shot.shape = }") + + start_time = time.time() + pc, indices = farthest_point_sample(pc, self.sample_points_num) # (N, 3) + pc_normal = pc_normal[indices] # (N, 3) + pc_shot = pc_shot[indices] # (N, 352) + if self.rgb: + pc_color = pc_color[indices] # (N, 3) + instance_mask = instance_mask[indices] # (N,) + function_mask = function_mask[indices] # (N,) + end_time = time.time() + if self.debug: + print(f"farthest_point_sample: {end_time - start_time}") + print(f"{pc.shape = }", f"{pc_normal.shape = }", f"{pc_shot.shape = }") + + if self.debug and self.vis: + # camera coordinate as (x, y, z) = (in, left, up) + # normal as inside/outside both + visualize_mask(pc, instance_mask, function_mask, pc_normal=pc_normal, + joint_translations=joint_translations, joint_rotations=joint_rotations, affordable_positions=affordable_positions, + whether_frame=True, whether_bbox=True, window_name='after') + + # sample point tuples + start_time = time.time() + point_idxs = random.sample(self.permutations, self.sample_tuples_num) + point_idxs = np.array(point_idxs, dtype=np.int64) # (N_t, 2) + point_idxs_more = np.random.randint(0, self.sample_points_num, size=(self.sample_tuples_num, self.tuple_more_num), dtype=np.int64) # (N_t, N_m) + point_idxs_all = np.concatenate([point_idxs, point_idxs_more], axis=-1) # (N_t, 2 + N_m) + base_mask = (instance_mask == 0) + base_idxs = np.where(base_mask)[0] + part_idxs = [] + for j in range(joint_num): + part_mask = (instance_mask == j + 1) + part_idxs.append(np.where(part_mask)[0]) + base_idxs_mask = np.isin(point_idxs, base_idxs) + base_case_mask = np.all(base_idxs_mask, axis=-1) + part_base_case_mask = np.logical_and(np.any(base_idxs_mask, axis=-1), np.logical_not(base_case_mask)) + part_case_mask = np.logical_not(np.logical_or(base_case_mask, part_base_case_mask)) + end_time = time.time() + if self.debug: + print(f"sample_point_tuples: {end_time - start_time}") + print(f"{point_idxs.shape = }, {point_idxs_all.shape = }") + if self.debug and self.vis: + visualize_confidence_voting(np.ones((self.sample_tuples_num,)), pc, point_idxs, whether_frame=True, whether_bbox=True, window_name='point tuples') + + # generate targets + start_time = time.time() + targets_tr = np.zeros((joint_num, self.sample_tuples_num, 2), dtype=np.float32) # (J, N_t, 2) + targets_rot = np.zeros((joint_num, self.sample_tuples_num), dtype=np.float32) # (J, N_t) + targets_afford = np.zeros((joint_num, self.sample_tuples_num, 2), dtype=np.float32) # (J, N_t, 2) + targets_conf = np.zeros((joint_num, self.sample_tuples_num), dtype=np.float32) # (J, N_t) + for j in range(joint_num): + this_part_mask = np.any(np.isin(point_idxs[part_base_case_mask], part_idxs[j]), axis=-1) + same_part_mask = np.all(np.isin(point_idxs[part_case_mask], part_idxs[j]), axis=-1) + merge_this_part_mask = part_base_case_mask.copy() + merge_this_part_mask[part_base_case_mask] = this_part_mask.copy() + merge_same_part_mask = part_case_mask.copy() + 
merge_same_part_mask[part_case_mask] = same_part_mask.copy() + targets_tr[j] = generate_target_tr(pc, joint_translations[j], point_idxs) + targets_rot[j] = generate_target_rot(pc, joint_rotations[j], point_idxs) + targets_afford[j] = generate_target_tr(pc, affordable_positions[j], point_idxs) + # targets_conf[j, merge_same_part_mask] = 0.51 + targets_conf[j, merge_this_part_mask] = 1.0 + end_time = time.time() + if self.debug: + print(f"generate_targets: {end_time - start_time}") + print(f"{targets_tr.shape = }", f"{targets_rot.shape = }", f"{targets_conf.shape = }", f"{targets_afford.shape = }") + + if self.is_train: + if self.rgb: + return pc, pc_normal, pc_shot, pc_color, \ + targets_tr, targets_rot, targets_afford, targets_conf, \ + point_idxs_all + else: + return pc, pc_normal, pc_shot, \ + targets_tr, targets_rot, targets_afford, targets_conf, \ + point_idxs_all + else: + # actually during testing, the targets should not be known, here they are used to test gt + if self.rgb: + return pc, pc_normal, pc_shot, pc_color, \ + joint_translations, joint_rotations, affordable_positions, \ + targets_tr, targets_rot, targets_afford, targets_conf, \ + point_idxs_all + else: + return pc, pc_normal, pc_shot, \ + joint_translations, joint_rotations, affordable_positions, \ + targets_tr, targets_rot, targets_afford, targets_conf, \ + point_idxs_all + + +if __name__ == '__main__': + setup_seed(seed=42) + path = "/data2/junbo/sapien4/test" + categories = ['Microwave'] + joint_num = 1 + # resolution = 2.5e-2 + resolution = 1e-2 + receptive_field = 10 + sample_points_num = 1024 + sample_tuples_num = 100000 + tuple_more_num = 3 + # normalize = 'median' + normalize = 'none' + noise = True + distortion_rate = 0.1 + distortion_level = 0.01 + outlier_rate = 0.001 + outlier_level = 1.0 + rgb = True + denoise = False + dataset = ArticulationDataset(path, categories, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + noise, distortion_rate, distortion_level, outlier_rate, outlier_level, + rgb, denoise, normalize, debug=True, vis=False, is_train=True) + batch_size = 1 + num_workers = 0 + dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers) + for results in tqdm.tqdm(dataloader): + import pdb; pdb.set_trace() diff --git a/deps/.gitkeep b/deps/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/envs/camera.py b/envs/camera.py new file mode 100644 index 0000000..a2c229a --- /dev/null +++ b/envs/camera.py @@ -0,0 +1,209 @@ +""" +Modified from https://github.com/warshallrho/VAT-Mart/blob/main/code/camera.py +""" +import numpy as np +from sapien.core import Pose + +from .env import Env + +class Camera(object): + + def __init__(self, env:Env, near=0.1, far=100.0, image_size=448, dist=5.0, \ + phi=np.pi/5, theta=np.pi, fov=35.0, random_position=False, fixed_position=False, restrict_dir=False, real_data=False): + builder = env.scene.create_actor_builder() + camera_mount_actor = builder.build(is_kinematic=True) + self.env = env + + # set camera intrinsics + if isinstance(image_size, int): + width = image_size + height = image_size + else: + width = image_size[0] + height = image_size[1] + if isinstance(fov, float): + fovx = 0 + fovy = fov + else: + fovx = fov[0] + fovy = fov[1] + self.camera = env.scene.add_mounted_camera('camera', camera_mount_actor, Pose(), \ + width, height, np.deg2rad(fovx), np.deg2rad(fovy), near, far) + + # set camera extrinsics + if random_position: + if 
restrict_dir: + theta = (0.25 + np.random.random() * 0.5) * np.pi*2 + phi = (np.random.random()+1) * np.pi/6 + else: + theta = np.random.random() * np.pi*2 + phi = (np.random.random()+1) * np.pi/6 + if fixed_position: + theta = np.pi + phi = np.pi/10 + if random_position and real_data: + theta = np.random.random() * np.pi*2 + phi = np.random.random() * np.pi * 2 + + pos = np.array([dist*np.cos(phi)*np.cos(theta), \ + dist*np.cos(phi)*np.sin(theta), \ + dist*np.sin(phi)]) + forward = -pos / np.linalg.norm(pos) + left = np.cross([0, 0, 1], forward) + left = left / np.linalg.norm(left) + up = np.cross(forward, left) + mat44 = np.eye(4) + mat44[:3, :3] = np.vstack([forward, left, up]).T + mat44[:3, 3] = pos # mat44 is cam2world + mat44[0, 3] += env.object_position_offset + self.mat44 = mat44 + camera_mount_actor.set_pose(Pose.from_transformation_matrix(mat44)) + + # log parameters + self.near = near + self.far = far + self.fov = [fovx, fovy] + self.dist = dist + self.theta = theta + self.phi = phi + self.pos = pos + + self.camera_mount_actor = camera_mount_actor + + def change_pose(self, + phi=np.pi/5, theta=np.pi, dist=5.0, random_position=False, restrict_dir=False): + # set camera extrinsics + if random_position: + if restrict_dir: + theta = (0.25 + np.random.random() * 0.5) * np.pi*2 + phi = (np.random.random()+1) * np.pi/6 + else: + theta = np.random.random() * np.pi*2 + phi = (np.random.random()+1) * np.pi/6 + pos = np.array([dist*np.cos(phi)*np.cos(theta), \ + dist*np.cos(phi)*np.sin(theta), \ + dist*np.sin(phi)]) + forward = -pos / np.linalg.norm(pos) + left = np.cross([0, 0, 1], forward) + left = left / np.linalg.norm(left) + up = np.cross(forward, left) + mat44 = np.eye(4) + mat44[:3, :3] = np.vstack([forward, left, up]).T + mat44[:3, 3] = pos # mat44 is cam2world + mat44[0, 3] += self.env.object_position_offset + self.mat44 = mat44 + self.camera_mount_actor.set_pose(Pose.from_transformation_matrix(mat44)) + + self.dist = dist + self.theta = theta + self.phi = phi + self.pos = pos + + def change_fov(self, fov): + if isinstance(fov, float): + fovx = 0 + fovy = fov + else: + fovx = fov[0] + fovy = fov[1] + self.camera.set_mode_perspective(np.deg2rad(fovy)) + self.fov = [fovx, fovy] + + def get_observation(self): + self.camera.take_picture() + rgba = self.camera.get_color_rgba() + rgba = (rgba * 255).clip(0, 255).astype(np.float32) / 255 + white = np.ones((rgba.shape[0], rgba.shape[1], 3), dtype=np.float32) + mask = np.tile(rgba[:, :, 3:4], [1, 1, 3]) + rgb = rgba[:, :, :3] * mask + white * (1 - mask) + depth = self.camera.get_depth().astype(np.float32) + return rgb, depth + + def compute_camera_XYZA(self, depth): + camera_matrix = self.camera.get_camera_matrix()[:3, :3] + y, x = np.where(depth < 1) + z = self.near * self.far / (self.far + depth * (self.near - self.far)) + permutation = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) + points = (permutation @ np.dot(np.linalg.inv(camera_matrix), \ + np.stack([x, y, np.ones_like(x)] * z[y, x], 0))).T + return y, x, points + + @staticmethod + def compute_XYZA_matrix(id1, id2, pts, size1, size2): + out = np.zeros((size1, size2, 4), dtype=np.float32) + out[id1, id2, :3] = pts + out[id1, id2, 3] = 1 + return out + + def get_normal_map(self): + nor = self.camera.get_normal_rgba() + # convert from PartNet-space (x-right, y-up, z-backward) to SAPIEN-space (x-front, y-left, z-up) + new_nor = np.array(nor, dtype=np.float32) + new_nor[:, :, 0] = -nor[:, :, 2] + new_nor[:, :, 1] = -nor[:, :, 0] + new_nor[:, :, 2] = nor[:, :, 1] + return new_nor + + 
def get_movable_link_mask(self, link_ids): + link_seg = self.camera.get_segmentation() + link_mask = np.zeros((link_seg.shape[0], link_seg.shape[1])).astype(np.uint8) + for idx, lid in enumerate(link_ids): + cur_link_pixels = int(np.sum(link_seg==lid)) + if cur_link_pixels > 0: + link_mask[link_seg == lid] = idx+1 + return link_mask + + def get_handle_mask(self): + # read part seg partid2renderids + partid2renderids = dict() + for k in self.env.scene.render_id_to_visual_name: + if self.env.scene.render_id_to_visual_name[k].split('-')[0] == 'handle': + part_id = int(self.env.scene.render_id_to_visual_name[k].split('-')[-1]) + if part_id not in partid2renderids: + partid2renderids[part_id] = [] + partid2renderids[part_id].append(k) + # generate 0/1 handle mask + part_seg = self.camera.get_obj_segmentation() + handle_mask = np.zeros((part_seg.shape[0], part_seg.shape[1])).astype(np.uint8) + for partid in partid2renderids: + cur_part_mask = np.isin(part_seg, partid2renderids[partid]) + cur_part_mask_pixels = int(np.sum(cur_part_mask)) + if cur_part_mask_pixels > 0: + handle_mask[cur_part_mask] = 1 + return handle_mask + + def get_object_mask(self): + rgba = self.camera.get_albedo_rgba() + return rgba[:, :, 3] > 0.5 + + # return camera parameters + def get_metadata(self): + return { + 'pose': self.camera.get_pose(), + 'near': self.camera.get_near(), + 'far': self.camera.get_far(), + 'width': self.camera.get_width(), + 'height': self.camera.get_height(), + 'fov': self.camera.get_fovy(), + 'camera_matrix': self.camera.get_camera_matrix(), + 'projection_matrix': self.camera.get_projection_matrix(), + 'model_matrix': self.camera.get_model_matrix(), + 'mat44': self.mat44, + } + + # return camera parameters + def get_metadata_json(self): + return { + 'dist': self.dist, + 'theta': self.theta, + 'phi': self.phi, + 'near': self.camera.get_near(), + 'far': self.camera.get_far(), + 'width': self.camera.get_width(), + 'height': self.camera.get_height(), + 'fov': self.camera.get_fovy(), + 'camera_matrix': self.camera.get_camera_matrix().tolist(), + 'projection_matrix': self.camera.get_projection_matrix().tolist(), + 'model_matrix': self.camera.get_model_matrix().tolist(), + 'mat44': self.mat44.tolist(), + } diff --git a/envs/env.py b/envs/env.py new file mode 100644 index 0000000..d71fbd4 --- /dev/null +++ b/envs/env.py @@ -0,0 +1,628 @@ +""" +Modified from https://github.com/warshallrho/VAT-Mart/blob/main/code/env.py +""" + +from __future__ import division +import sapien.core as sapien +from sapien.core import Pose, SceneConfig, OptifuserConfig, ArticulationJointType +import numpy as np +import trimesh + + +def process_angle_limit(x): + if np.isneginf(x): + x = -10 + if np.isinf(x): + x = 10 + return x + +def get_random_number(l, r): + return np.random.rand() * (r - l) + l + + +class ContactError(Exception): + pass + + +class SVDError(Exception): + pass + + +class Env(object): + + def __init__(self, flog=None, show_gui=True, render_rate=20, timestep=1/500, \ + object_position_offset=0.0, succ_ratio=0.1): + self.current_step = 0 + + self.flog = flog + self.show_gui = show_gui + self.render_rate = render_rate + self.timestep = timestep + self.succ_ratio = succ_ratio + self.object_position_offset = object_position_offset + + # engine and renderer + self.engine = sapien.Engine(0, 0.001, 0.005) + + render_config = OptifuserConfig() + render_config.shadow_map_size = 8192 + render_config.shadow_frustum_size = 10 + render_config.use_shadow = False + render_config.use_ao = True + + self.renderer = 
sapien.OptifuserRenderer(config=render_config) + self.renderer.enable_global_axes(False) + + self.engine.set_renderer(self.renderer) + + # GUI + self.window = False + if show_gui: + self.renderer_controller = sapien.OptifuserController(self.renderer) + self.renderer_controller.set_camera_position(-3.0+object_position_offset, 1.0, 3.0) + self.renderer_controller.set_camera_rotation(-0.4, -0.8) + + # scene + scene_config = SceneConfig() + scene_config.gravity = [0, 0, 0] + scene_config.solver_iterations = 20 + scene_config.enable_pcm = False + scene_config.sleep_threshold = 0.0 + + self.scene = self.engine.create_scene(config=scene_config) + if show_gui: + self.renderer_controller.set_current_scene(self.scene) + + self.scene.set_timestep(timestep) + + # add lights + self.scene.set_ambient_light([0.5, 0.5, 0.5]) + self.scene.set_shadow_light([0, 1, -1], [0.5, 0.5, 0.5]) + self.scene.add_point_light([1+object_position_offset, 2, 2], [1, 1, 1]) + self.scene.add_point_light([1+object_position_offset, -2, 2], [1, 1, 1]) + self.scene.add_point_light([-1+object_position_offset, 0, 1], [1, 1, 1]) + + # default Nones + self.object = None + self.object_target_joint = None + + # check contact + self.check_contact = False + self.contact_error = False + + # visual objects + self.visual_builder = self.scene.create_actor_builder() + self.visual_objects = dict() + + def set_controller_camera_pose(self, x, y, z, yaw, pitch): + self.renderer_controller.set_camera_position(x, y, z) + self.renderer_controller.set_camera_rotation(yaw, pitch) + self.renderer_controller.render() + + def load_object(self, urdf, material, state='closed', target_part_id=-1, target_part_idx=-1, scale=1.0): + # NOTE: set target_part_idx only set other joints to nearly closed, will not track the target joint + loader = self.scene.create_urdf_loader() + loader.scale = scale + self.object = loader.load(urdf, {"material": material}) + pose = Pose([self.object_position_offset, 0, 0], [1, 0, 0, 0]) + self.object.set_root_pose(pose) + + # compute link actor information + self.all_link_ids = [l.get_id() for l in self.object.get_links()] + self.all_link_names = [l.get_name() for l in self.object.get_links()] + self.movable_link_ids = [] + self.movable_joint_idxs = [] + self.movable_link_joint_types = [] + self.movable_link_joint_names = [] + + for j_idx, j in enumerate(self.object.get_joints()): + if j.get_dof() == 1: + if j.type == ArticulationJointType.REVOLUTE: + self.movable_link_joint_types.append(0) + if j.type == ArticulationJointType.PRISMATIC: + self.movable_link_joint_types.append(1) + self.movable_link_joint_names.append(j.get_name()) + + self.movable_link_ids.append(j.get_child_link().get_id()) + self.movable_joint_idxs.append(j_idx) + if self.flog is not None: + self.flog.write('All Actor Link IDs: %s\n' % str(self.all_link_ids)) + self.flog.write('All Movable Actor Link IDs: %s\n' % str(self.movable_link_ids)) + + # set joint property + for joint in self.object.get_joints(): + joint.set_drive_property(stiffness=0, damping=10) + + # set initial qpos + joint_angles = [] + joint_abs_angles = [] + self.joint_angles_lower = [] + self.joint_angles_upper = [] + target_part_joint_idx = -1 + joint_idx = 0 + for j in self.object.get_joints(): + if j.get_dof() == 1: + if j.get_child_link().get_id() == target_part_id: + target_part_joint_idx = len(joint_angles) + l = process_angle_limit(j.get_limits()[0, 0]) + self.joint_angles_lower.append(float(l)) + r = process_angle_limit(j.get_limits()[0, 1]) + self.joint_angles_upper.append(float(r)) + 
if state == 'closed': + joint_angles.append(float(l)) + elif state == 'open': + joint_angles.append(float(r)) + elif state == 'random-middle': + joint_angles.append(float(get_random_number(l, r))) + elif state == 'random-closed-middle': + if np.random.random() < 0.5: + joint_angles.append(float(get_random_number(l, r))) + else: + joint_angles.append(float(l)) + elif state == 'random-middle-middle': + if joint_idx == target_part_idx: + joint_angles.append(float(get_random_number(l + 0.15*(r-l), r - 0.15*(r-l)))) + else: + joint_angles.append(float(get_random_number(l, l + 0.0*(r-l)))) + else: + raise ValueError('ERROR: object init state %s unknown!' % state) + joint_abs_angles.append((joint_angles[-1]-l)/(r-l)) + joint_idx += 1 + + self.object.set_qpos(joint_angles) + if target_part_id >= 0: + return joint_angles, target_part_joint_idx, joint_abs_angles + return joint_angles, joint_abs_angles + + def load_real_object(self, urdf, material, joint_angles=None): + loader = self.scene.create_urdf_loader() + self.object = loader.load(urdf, {"material": material}) + pose = Pose([self.object_position_offset, 0, 0], [1, 0, 0, 0]) + self.object.set_root_pose(pose) + + # compute link actor information + self.all_link_ids = [l.get_id() for l in self.object.get_links()] + self.movable_link_ids = [] + for j in self.object.get_joints(): + if j.get_dof() == 1: + self.movable_link_ids.append(j.get_child_link().get_id()) + if self.flog is not None: + self.flog.write('All Actor Link IDs: %s\n' % str(self.all_link_ids)) + self.flog.write('All Movable Actor Link IDs: %s\n' % str(self.movable_link_ids)) + + # set joint property + for joint in self.object.get_joints(): + joint.set_drive_property(stiffness=0, damping=10) + + if joint_angles is not None: + self.object.set_qpos(joint_angles) + + return None + + def update_and_set_joint_angles_all(self, state='closed'): + joint_angles = [] + for j in self.object.get_joints(): + if j.get_dof() == 1: + l = process_angle_limit(j.get_limits()[0, 0]) + self.joint_angles_lower.append(float(l)) + r = process_angle_limit(j.get_limits()[0, 1]) + self.joint_angles_upper.append(float(r)) + if state == 'closed': + joint_angles.append(float(l)) + elif state == 'open': + joint_angles.append(float(r)) + elif state == 'random-middle': + joint_angles.append(float(get_random_number(l, r))) + elif state == 'random-closed-middle': + if np.random.random() < 0.5: + joint_angles.append(float(get_random_number(l, r))) + else: + joint_angles.append(float(l)) + else: + raise ValueError('ERROR: object init state %s unknown!' 
% state) + self.object.set_qpos(joint_angles) + return joint_angles + + def get_target_part_axes_new(self, target_part_id): + joint_axes = None + for j in self.object.get_joints(): + if j.get_dof() == 1: + if j.get_child_link().get_id() == target_part_id: + pos = j.get_global_pose() + mat = pos.to_transformation_matrix() + joint_axes = [float(-mat[0, 0]), float(-mat[1, 0]), float(mat[2, 0])] + if joint_axes is None: + raise ValueError('joint axes error!') + + return joint_axes + + def get_target_part_axes_dir_new(self, target_part_id): + joint_axes = self.get_target_part_axes_new(target_part_id=target_part_id) + axes_dir = -1 + for idx_axes_dim in range(3): + if abs(joint_axes[idx_axes_dim]) > 0.1: + axes_dir = idx_axes_dim + return axes_dir + + def get_target_part_origins_new(self, target_part_id): + joint_origins = None + for j in self.object.get_joints(): + if j.get_dof() == 1: + if j.get_child_link().get_id() == target_part_id: + pos = j.get_global_pose() + joint_origins = pos.p.tolist() + if joint_origins is None: + raise ValueError('joint origins error!') + + return joint_origins + + def update_joint_angle(self, joint_angles, target_part_joint_idx, state, task_lower, push=True, pull=False, drawer=False): + if push: + if drawer: + l = max(self.joint_angles_lower[target_part_joint_idx], self.joint_angles_lower[target_part_joint_idx] + task_lower) + r = self.joint_angles_upper[target_part_joint_idx] + else: + l = max(self.joint_angles_lower[target_part_joint_idx], self.joint_angles_lower[target_part_joint_idx] + task_lower * np.pi / 180) + r = self.joint_angles_upper[target_part_joint_idx] + if pull: + if drawer: + l = self.joint_angles_lower[target_part_joint_idx] + r = self.joint_angles_upper[target_part_joint_idx] - task_lower + else: + l = self.joint_angles_lower[target_part_joint_idx] + r = self.joint_angles_upper[target_part_joint_idx] - task_lower * np.pi / 180 + if state == 'closed': + joint_angles[target_part_joint_idx] = (float(l)) + elif state == 'open': + joint_angles[target_part_joint_idx] = float(r) + elif state == 'random-middle': + joint_angles[target_part_joint_idx] = float(get_random_number(l, r)) + elif state == 'random-middle-open': + joint_angles[target_part_joint_idx] = float(get_random_number(r * 0.8, r)) + elif state == 'random-closed-middle': + if np.random.random() < 0.5: + joint_angles[target_part_joint_idx] = float(get_random_number(l, r)) + else: + joint_angles[target_part_joint_idx] = float(l) + else: + raise ValueError('ERROR: object init state %s unknown!' 
% state) + return joint_angles + + def set_object_joint_angles(self, joint_angles): + self.object.set_qpos(joint_angles) + + def set_target_object_part_actor_id(self, actor_id): + if self.flog is not None: + self.flog.write('Set Target Object Part Actor ID: %d\n' % actor_id) + self.target_object_part_actor_id = actor_id + self.non_target_object_part_actor_id = list(set(self.all_link_ids) - set([actor_id])) + + # get the link handler + for j in self.object.get_joints(): + if j.get_dof() == 1: + if j.get_child_link().get_id() == actor_id: + self.target_object_part_actor_link = j.get_child_link() + + # moniter the target joint + idx = 0 + for j in self.object.get_joints(): + if j.get_dof() == 1: + if j.get_child_link().get_id() == actor_id: + self.target_object_part_joint_id = idx + self.target_object_part_joint_type = j.type + idx += 1 + + def get_object_qpos(self): + return self.object.get_qpos() + + def get_target_part_qpos(self): + qpos = self.object.get_qpos() + return float(qpos[self.target_object_part_joint_id]) + + def get_target_part_state(self): + qpos = self.object.get_qpos() + return float(qpos[self.target_object_part_joint_id]) - self.joint_angles_lower[self.target_object_part_joint_id] + + def get_target_part_pose(self): + return self.target_object_part_actor_link.get_pose() + + def start_checking_contact(self, robot_hand_actor_id, robot_gripper_actor_ids, strict): + self.check_contact = True + self.check_contact_strict = strict + self.first_timestep_check_contact = True + self.robot_hand_actor_id = robot_hand_actor_id + self.robot_gripper_actor_ids = robot_gripper_actor_ids + self.contact_error = False + + def end_checking_contact(self, robot_hand_actor_id, robot_gripper_actor_ids, strict): + self.check_contact = False + self.check_contact_strict = strict + self.first_timestep_check_contact = False + self.robot_hand_actor_id = robot_hand_actor_id + self.robot_gripper_actor_ids = robot_gripper_actor_ids + + def get_material(self, static_friction, dynamic_friction, restitution): + return self.engine.create_physical_material(static_friction, dynamic_friction, restitution) + + def render(self): + if self.show_gui and (not self.window): + self.window = True + self.renderer_controller.show_window() + self.scene.update_render() + if self.show_gui and (self.current_step % self.render_rate == 0): + self.renderer_controller.render() + + def step(self): + self.current_step += 1 + self.scene.step() + if self.check_contact: + if not self.check_contact_is_valid(): + raise ContactError() + + # check the first contact: only gripper links can touch the target object part link + def check_contact_is_valid(self): + self.contacts = self.scene.get_contacts() + contact = False; valid = False; + for c in self.contacts: + aid1 = c.actor1.get_id() + aid2 = c.actor2.get_id() + has_impulse = False + for p in c.points: + if abs(p.impulse @ p.impulse) > 1e-4: + has_impulse = True + break + if has_impulse: + if (aid1 in self.robot_gripper_actor_ids and aid2 == self.target_object_part_actor_id) or \ + (aid2 in self.robot_gripper_actor_ids and aid1 == self.target_object_part_actor_id): + contact, valid = True, True + if (aid1 in self.robot_gripper_actor_ids and aid2 in self.non_target_object_part_actor_id) or \ + (aid2 in self.robot_gripper_actor_ids and aid1 in self.non_target_object_part_actor_id): + if self.check_contact_strict: + self.contact_error = True + return False + else: + contact, valid = True, True + if (aid1 == self.robot_hand_actor_id or aid2 == self.robot_hand_actor_id): + if 
self.check_contact_strict: + self.contact_error = True + return False + else: + contact, valid = True, True + # starting pose should have no collision at all + if (aid1 in self.robot_gripper_actor_ids or aid1 == self.robot_hand_actor_id or \ + aid2 in self.robot_gripper_actor_ids or aid2 == self.robot_hand_actor_id) and self.first_timestep_check_contact: + self.contact_error = True + return False + + self.first_timestep_check_contact = False + if contact and valid: + self.check_contact = False + return True + + def check_contact_right(self): + contacts = self.scene.get_contacts() + if len(contacts) < len(self.robot_gripper_actor_ids): + # print("no enough contacts") + return False + # for c in contacts: + # aid1 = c.actor1.get_id() + # aid2 = c.actor2.get_id() + # if (aid1 in self.robot_gripper_actor_ids and aid2 == self.target_object_part_actor_id) or \ + # (aid2 in self.robot_gripper_actor_ids and aid1 == self.target_object_part_actor_id): + # print("right") + # pass + # elif (aid1 in self.robot_gripper_actor_ids and aid2 in self.non_target_object_part_actor_id) or \ + # (aid2 in self.robot_gripper_actor_ids and aid1 in self.non_target_object_part_actor_id): + # print("unright") + # right = False + # elif (aid1 == self.robot_hand_actor_id or aid2 == self.robot_hand_actor_id): + # print("hand contact") + # right = False + # elif (aid1 in self.robot_gripper_actor_ids and aid2 in self.robot_gripper_actor_ids): + # print("also non-successful grasp") + # right = False + # else: + # print(c.actor1.get_name(), c.actor2.get_name()) + right_aids = [] + for c in contacts: + aid1 = c.actor1.get_id() + aid2 = c.actor2.get_id() + if (aid1 in self.robot_gripper_actor_ids and aid2 == self.target_object_part_actor_id): + right_aids.append(aid1) + elif (aid2 in self.robot_gripper_actor_ids and aid1 == self.target_object_part_actor_id): + right_aids.append(aid2) + else: + pass + right = (set(right_aids) == set(self.robot_gripper_actor_ids)) + return right + + def close_render(self): + if self.window: + self.renderer_controller.hide_window() + self.window = False + + def wait_to_start(self): + print('press q to start\n') + while not self.renderer_controller.should_quit: + self.scene.update_render() + if self.show_gui: + self.renderer_controller.render() + + def close(self): + if self.show_gui: + self.renderer_controller.set_current_scene(None) + self.scene = None + + def get_global_mesh(self, obj): + final_vs = []; + final_fs = []; + vid = 0; + for l in obj.get_links(): + vs = [] + for s in l.get_collision_shapes(): + v = np.array(s.convex_mesh_geometry.vertices, dtype=np.float32) + f = np.array(s.convex_mesh_geometry.indices, dtype=np.uint32).reshape(-1, 3) + vscale = s.convex_mesh_geometry.scale + v[:, 0] *= vscale[0]; + v[:, 1] *= vscale[1]; + v[:, 2] *= vscale[2]; + ones = np.ones((v.shape[0], 1), dtype=np.float32) + v_ones = np.concatenate([v, ones], axis=1) + transmat = s.pose.to_transformation_matrix() + v = (v_ones @ transmat.T)[:, :3] + vs.append(v) + final_fs.append(f + vid) + vid += v.shape[0] + if len(vs) > 0: + vs = np.concatenate(vs, axis=0) + ones = np.ones((vs.shape[0], 1), dtype=np.float32) + vs_ones = np.concatenate([vs, ones], axis=1) + transmat = l.get_pose().to_transformation_matrix() + vs = (vs_ones @ transmat.T)[:, :3] + final_vs.append(vs) + final_vs = np.concatenate(final_vs, axis=0) + final_fs = np.concatenate(final_fs, axis=0) + return final_vs, final_fs + + def get_part_mesh(self, obj, part_id): + final_vs = []; + final_fs = []; + vid = 0; + for l in obj.get_links(): + l_id = 
l.get_id() + if l_id != part_id: + continue + vs = [] + for s in l.get_collision_shapes(): + v = np.array(s.convex_mesh_geometry.vertices, dtype=np.float32) + f = np.array(s.convex_mesh_geometry.indices, dtype=np.uint32).reshape(-1, 3) + vscale = s.convex_mesh_geometry.scale + v[:, 0] *= vscale[0]; + v[:, 1] *= vscale[1]; + v[:, 2] *= vscale[2]; + ones = np.ones((v.shape[0], 1), dtype=np.float32) + v_ones = np.concatenate([v, ones], axis=1) + transmat = s.pose.to_transformation_matrix() + v = (v_ones @ transmat.T)[:, :3] + vs.append(v) + final_fs.append(f + vid) + vid += v.shape[0] + if len(vs) > 0: + vs = np.concatenate(vs, axis=0) + ones = np.ones((vs.shape[0], 1), dtype=np.float32) + vs_ones = np.concatenate([vs, ones], axis=1) + transmat = l.get_pose().to_transformation_matrix() + vs = (vs_ones @ transmat.T)[:, :3] + final_vs.append(vs) + final_vs = np.concatenate(final_vs, axis=0) + final_fs = np.concatenate(final_fs, axis=0) + return final_vs, final_fs + + def sample_pc(self, v, f, n_points=4096): + mesh = trimesh.Trimesh(vertices=v, faces=f) + points, __ = trimesh.sample.sample_surface(mesh=mesh, count=n_points) + return points + + def check_drawer(self): + for j in self.object.get_joints(): + if j.get_dof() == 1 and (j.type == ArticulationJointType.PRISMATIC): + return True + return False + + def add_point_visual(self, point:np.ndarray, color=[1, 0, 0], radius=0.04, name='point_visual'): + self.visual_builder.add_sphere_visual(pose=Pose(p=point), radius=radius, color=color, name=name) + point_visual = self.visual_builder.build_static(name=name) + assert name not in self.visual_objects.keys() + self.visual_objects[name] = point_visual + + def add_line_visual(self, point1:np.ndarray, point2:np.ndarray, color=[1, 0, 0], width=0.03, name='line_visual'): + direction = point2 - point1 + direction = direction / np.linalg.norm(direction) + rotation = np.zeros((3, 3)) + temp2 = np.cross(direction, np.array([1., 0., 0.])) + if np.linalg.norm(temp2) < 1e-6: + temp1 = np.cross(np.array([0., 1., 0.]), direction) + temp1 /= np.linalg.norm(temp1) + temp2 = np.cross(direction, temp1) + temp2 /= np.linalg.norm(temp2) + else: + temp2 /= np.linalg.norm(temp2) + temp1 = np.cross(temp2, direction) + temp1 /= np.linalg.norm(temp1) + rotation[:, 0] = temp1 + rotation[:, 1] = temp2 + rotation[:, 2] = direction + pose_transformation = np.eye(4) + pose_transformation[:3, 3] = (point1+point2)/2 + pose_transformation[:3, :3] = rotation + pose = Pose().from_transformation_matrix(pose_transformation) + size = [width/2, width/2, np.linalg.norm(point1 - point2)/2] + self.visual_builder.add_box_visual(pose=pose, size=size, color=color, name=name) + line_visual = self.visual_builder.build_static(name=name) + assert name not in self.visual_objects.keys() + self.visual_objects[name] = line_visual + + def add_grasp_visual(self, grasp_width, grasp_depth, grasp_translation, grasp_rotation, affordance=0.5, name='grasp_visual'): + finger_width = 0.004 + tail_length = 0.04 + depth_base = 0.02 + gg_width = grasp_width + gg_depth = grasp_depth + gg_translation = grasp_translation + gg_rotation = grasp_rotation + + left = np.zeros((2, 3)) + left[0] = np.array([-depth_base - finger_width, -gg_width / 2, 0]) + left[1] = np.array([gg_depth, -gg_width / 2, 0]) + + right = np.zeros((2, 3)) + right[0] = np.array([-depth_base - finger_width, gg_width / 2, 0]) + right[1] = np.array([gg_depth, gg_width / 2, 0]) + + bottom = np.zeros((2, 3)) + bottom[0] = np.array([-depth_base - finger_width, -gg_width / 2, 0]) + bottom[1] = 
np.array([-depth_base - finger_width, gg_width / 2, 0]) + + tail = np.zeros((2, 3)) + tail[0] = np.array([-(tail_length + finger_width + depth_base), 0, 0]) + tail[1] = np.array([-(finger_width + depth_base), 0, 0]) + + vertices = np.vstack([left, right, bottom, tail]) + vertices = np.dot(gg_rotation, vertices.T).T + gg_translation + + if affordance < 0.5: + color = [1, 2*affordance, 0] + elif affordance == 1.0: + color = [0, 0, 1] + else: + color = [-2*affordance+2, 1, 0] + self.add_line_visual(vertices[0], vertices[1], color, width=0.005, name=name+'_left') + self.add_line_visual(vertices[2], vertices[3], color, width=0.005, name=name+'_right') + self.add_line_visual(vertices[4], vertices[5], color, width=0.005, name=name+'_bottom') + self.add_line_visual(vertices[6], vertices[7], color, width=0.005, name=name+'_tail') + + def add_frame_visual(self): + self.add_line_visual(np.array([0, 0, 0]), np.array([1, 0, 0]), [1, 0, 0], width=0.005, name='frame_x') + self.add_line_visual(np.array([0, 0, 0]), np.array([0, 1, 0]), [0, 1, 0], width=0.005, name='frame_y') + self.add_line_visual(np.array([0, 0, 0]), np.array([0, 0, 1]), [0, 0, 1], width=0.005, name='frame_z') + + def remove_visual(self, name): + self.visual_builder.remove_visual_at(self.visual_objects[name].get_id()) + self.scene.remove_actor(self.visual_objects[name]) + del self.visual_objects[name] + + def remove_grasp_visual(self, name): + self.remove_visual(name+'_left') + self.remove_visual(name+'_right') + self.remove_visual(name+'_bottom') + self.remove_visual(name+'_tail') + + def remove_frame_visual(self): + self.remove_visual('frame_x') + self.remove_visual('frame_y') + self.remove_visual('frame_z') + + def remove_all_visuals(self): + names = list(self.visual_objects.keys()) + for name in names: + self.remove_visual(name) diff --git a/envs/real_camera.py b/envs/real_camera.py new file mode 100644 index 0000000..f2a0218 --- /dev/null +++ b/envs/real_camera.py @@ -0,0 +1,210 @@ +import pyrealsense2 as rs +import numpy as np +import cv2 + +class CameraL515(object): + def __init__(self): + self.pipeline = rs.pipeline() + self.config = rs.config() + self.config.enable_stream(rs.stream.depth, 1024, 768, rs.format.z16, 30) + self.config.enable_stream(rs.stream.color, 1280, 720, rs.format.bgr8, 30) + self.align_to = rs.stream.color + self.align = rs.align(self.align_to) + + self.pipeline_profile = self.pipeline.start(self.config) + try: + self.device = self.pipeline_profile.get_device() + self.mtx = self.getIntrinsics() + + self.hole_filling = rs.hole_filling_filter() + + align_to = rs.stream.color + self.align = rs.align(align_to) + + # camera init warm up + i = 60 + while i>0: + frames = self.pipeline.wait_for_frames() + aligned_frames = self.align.process(frames) + depth_frame = aligned_frames.get_depth_frame() + color_frame = aligned_frames.get_color_frame() + # pdb.set_trace() + color_frame.get_profile().as_video_stream_profile().get_intrinsics() + if not depth_frame or not color_frame: + continue + depth_image = np.asanyarray(depth_frame.get_data()) + color_image = np.asanyarray(color_frame.get_data()) + i -= 1 + except: + self.__del__() + raise + + def getIntrinsics(self): + frames = self.pipeline.wait_for_frames() + aligned_frames = self.align.process(frames) + color_frame = aligned_frames.get_color_frame() + intrinsics = color_frame.get_profile().as_video_stream_profile().get_intrinsics() + mtx = [intrinsics.width,intrinsics.height,intrinsics.ppx,intrinsics.ppy,intrinsics.fx,intrinsics.fy] + camIntrinsics = 
np.array([[mtx[4],0,mtx[2]], + [0,mtx[5],mtx[3]], + [0,0,1.]]) + return camIntrinsics + + def get_data(self, hole_filling=False): + while True: + frames = self.pipeline.wait_for_frames() + aligned_frames = self.align.process(frames) + depth_frame = aligned_frames.get_depth_frame() + if hole_filling: + depth_frame = self.hole_filling.process(depth_frame) + color_frame = aligned_frames.get_color_frame() + if not depth_frame or not color_frame: + continue + depth_image = np.asanyarray(depth_frame.get_data()) + color_image = np.asanyarray(color_frame.get_data()) + break + return color_image, depth_image + + def get_data1(self, hole_filling=False): + while True: + frames = self.pipeline.wait_for_frames() + # aligned_frames = self.align.process(frames) + # depth_frame = aligned_frames.get_depth_frame() + # if hole_filling: + # depth_frame = self.hole_filling.process(depth_frame) + # color_frame = aligned_frames.get_color_frame() + depth_frame = frames.get_depth_frame() + color_frame = frames.get_color_frame() + if not depth_frame or not color_frame: + continue + colorizer = rs.colorizer() + depth_image = np.asanyarray(colorizer.colorize(depth_frame).get_data()) + # depth_image = np.asanyarray(depth_frame.get_data()) + color_image = np.asanyarray(color_frame.get_data()) + break + return color_image, depth_image + + def inpaint(self, img, missing_value=0): + ''' + pip opencv-python == 3.4.8.29 + :param image: + :param roi: [x0,y0,x1,y1] + :param missing_value: + :return: + ''' + # cv2 inpainting doesn't handle the border properly + # https://stackoverflow.com/questions/25974033/inpainting-depth-map-still-a-black-image-border + img = cv2.copyMakeBorder(img, 1, 1, 1, 1, cv2.BORDER_DEFAULT) + mask = (img == missing_value).astype(np.uint8) + + # Scale to keep as float, but has to be in bounds -1:1 to keep opencv happy. + scale = np.abs(img).max() + img = img.astype(np.float32) / scale # Has to be float32, 64 not supported. + img = cv2.inpaint(img, mask, 1, cv2.INPAINT_NS) + + # Back to original size and value range. + img = img[1:-1, 1:-1] + img = img * scale + return img + + def getXYZRGB(self,color, depth, robot_pose,camee_pose,camIntrinsics,inpaint=True,depth_scale=None): + ''' + :param color: + :param depth: + :param robot_pose: array 4*4 + :param camee_pose: array 4*4 + :param camIntrinsics: array 3*3 + :param inpaint: bool + :param depth_scale: float, change to meter unit + :return: xyzrgb + ''' + heightIMG, widthIMG, _ = color.shape + # pdb.set_trace() + # heightIMG = 720 + # widthIMG = 1280 + # depthImg = depth / 10000. + assert depth_scale is not None + depthImg = depth * depth_scale + # depthImg = depth + if inpaint: + depthImg = self.inpaint(depthImg) + robot_pose = np.dot(robot_pose, camee_pose) + + [pixX, pixY] = np.meshgrid(np.arange(widthIMG), np.arange(heightIMG)) + camX = (pixX - camIntrinsics[0][2]) * depthImg / camIntrinsics[0][0] + camY = (pixY - camIntrinsics[1][2]) * depthImg / camIntrinsics[1][1] + camZ = depthImg + + camPts = [camX.reshape(camX.shape + (1,)), camY.reshape(camY.shape + (1,)), camZ.reshape(camZ.shape + (1,))] + camPts = np.concatenate(camPts, 2) + camPts = camPts.reshape((camPts.shape[0] * camPts.shape[1], camPts.shape[2])) # shape = (heightIMG*widthIMG, 3) + worldPts = np.dot(robot_pose[:3, :3], camPts.transpose()) + robot_pose[:3, 3].reshape(3, + 1) # shape = (3, heightIMG*widthIMG) + rgb = color.reshape((-1, 3)) / 255. 
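+        # The block above is the standard pinhole back-projection followed by a rigid transform:
+        #   x_cam = (u - cx) * z / fx,  y_cam = (v - cy) * z / fy,  z_cam = z
+        #   p_out = R @ p_cam + t,  where [R | t] = robot_pose @ camee_pose (camera-to-output frame, composed above)
+        # A minimal per-pixel sketch of the same math (illustrative values, not taken from this code):
+        #   u, v, z = 640, 360, 0.8  # pixel coordinates and depth in metres
+        #   x = (u - camIntrinsics[0][2]) * z / camIntrinsics[0][0]
+        #   y = (v - camIntrinsics[1][2]) * z / camIntrinsics[1][1]
+        #   p_out = robot_pose[:3, :3] @ np.array([x, y, z]) + robot_pose[:3, 3]
+        # Note: rgb is normalised to [0, 1] and is still in BGR order here; the channel swap just
+        # below converts it to RGB (the color stream is configured as rs.format.bgr8).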
+ rgb[:, [0, 2]] = rgb[:, [2, 0]] + xyzrgb = np.hstack((worldPts.T, rgb)) + # xyzrgb = self.getleft(xyzrgb) + return xyzrgb + + def getleft(self, obj1): + index = np.bitwise_and(obj1[:, 0] < 1.2, obj1[:, 0] > 0.2) + index = np.bitwise_and(obj1[:, 1] < 0.5, index) + index = np.bitwise_and(obj1[:, 1] > -0.5, index) + # index = np.bitwise_and(obj1[:, 2] > -0.1, index) + index = np.bitwise_and(obj1[:, 2] > 0.24, index) + index = np.bitwise_and(obj1[:, 2] < 0.6, index) + return obj1[index] + + + def __del__(self): + self.pipeline.stop() + + +def vis_pc(xyzrgb): + import open3d as o3d + camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + pc1 = o3d.geometry.PointCloud() + pc1.points = o3d.utility.Vector3dVector(xyzrgb[:, :3]) + pc1.colors = o3d.utility.Vector3dVector(xyzrgb[:, 3:]) + o3d.visualization.draw_geometries([camera_frame, pc1]) + + +if __name__ == "__main__": + print("initialize camera") + cam = CameraL515() + i = 0 + while True: + print(f"{i}th") + color, depth = cam.get_data(hole_filling=False) + + depth_sensor = cam.pipeline_profile.get_device().first_depth_sensor() + depth_scale = depth_sensor.get_depth_scale() + + # xyz in meter, rgb in [0, 1] + xyzrgb = cam.getXYZRGB(color, depth, np.identity(4), np.identity(4), cam.getIntrinsics(), inpaint=False, depth_scale=depth_scale) + # xyzrgb = xyzrgb[xyzrgb[:, 2] <= 1.5, :] + print(np.mean(xyzrgb[:, 2])) + vis_pc(xyzrgb) + + cv2.imshow('color', color) + while True: + if cv2.getWindowProperty('color', cv2.WND_PROP_VISIBLE) <= 0: + break + cv2.waitKey(1) + cv2.destroyAllWindows() + + cmd = input("whether save? (y/n): ") + if cmd == 'y': + cv2.imwrite(f"rgb_{i}.png", color) + np.savez(f"xyzrgb_{i}.npz", point_cloud=xyzrgb[:, :3], rgb=xyzrgb[:, 3:]) + i += 1 + elif cmd == 'n': + cmd = input("whether quit? 
(y/n): ") + if cmd == 'y': + break + elif cmd == 'n': + pass + else: + raise ValueError + else: + raise ValueError diff --git a/envs/real_robot.py b/envs/real_robot.py new file mode 100644 index 0000000..d8d5aee --- /dev/null +++ b/envs/real_robot.py @@ -0,0 +1,169 @@ +import time +import numpy as np +import transformations as tf +from frankx import Affine, JointMotion, Robot, Waypoint, WaypointMotion, Gripper, LinearRelativeMotion, LinearMotion, ImpedanceMotion + + +class Panda(): + def __init__(self,host='172.16.0.2'): + self.robot = Robot(host) + self.gripper = Gripper(host) + self.setGripper(20,0.1) + self.max_gripper_width = 0.08 + self.robot.set_default_behavior() + self.robot.recover_from_errors() + # Reduce the acceleration and velocity dynamic + self.robot.set_dynamic_rel(0.2) + # self.robot.set_dynamic_rel(0.05) + + # self.robot.velocity_rel = 0.1 + # self.robot.acceleration_rel = 0.02 + # self.robot.jerk_rel = 0.01 + + self.joint_tolerance = 0.01 + # state = self.robot.read_once() + # print('\nPose: ', self.robot.current_pose()) + # print('O_TT_E: ', state.O_T_EE) + # print('Joints: ', state.q) + # print('Elbow: ', state.elbow) + # pdb.set_trace() + self.in_impedance_control = False + + def setGripper(self,force=20.0, speed=0.02): + self.gripper.gripper_speed = speed # m/s + self.gripper.gripper_force = force # N + + def gripper_close(self) -> bool: + # can be used to grasp + is_graspped = self.gripper.clamp() + return is_graspped + + def gripper_open(self) -> None: + self.gripper.open() + + def gripper_release(self, width:float) -> None: + # can be used to release after grasping + self.gripper.release(min(max(width, 0.0), self.max_gripper_width)) + + def move_gripper(self, width:float) -> None: + self.gripper.move(min(max(width, 0.0), self.max_gripper_width), self.gripper.gripper_speed) # m + + def read_gripper(self) -> float: + return self.gripper.width() + + def is_grasping(self) -> bool: + return self.gripper.is_grasping() + + def moveJoint(self,joint,moveTarget=True): + assert len(joint)==7, "panda DOF is 7" + if not self.in_impedance_control: + self.robot.move(JointMotion(joint)) + else: + raise NotImplementedError + # while moveTarget: + # current = self.robot.read_once().q + # if all([np.abs(current[j] - joint[j]) None: + # joint = [-0.2918438928353856, -0.970780364569858, 0.10614118311070558, -1.3263296233896118, 0.28714199085241543, 1.4429661556967983, 0.8502855184922615] # microwave: pad + safe (rotate) + # joint = [-0.32146529861813555, -0.6174831717455548, 0.08796035485936884, -0.8542264670393085, 0.2846642250021548, 1.2692416279845777, 0.7918693021188179] # refrigerator: storagefurniture + # joint = [0.07228589984826875, -0.856545356798933, -0.005984785356738588, -1.446693722022207, -0.0739646362066269, 1.5132004619969288, 0.8178283093992272] # safe: pad + microwave + # joint = [-0.2249300478901436, -0.8004290411025673, 0.10279602963647609, -1.2284506426734476, 0.22189371273337696, 1.3787900806797873, 0.7783415511498849] # storagefurniture: microwave + # joint = [-0.2249300478901436, -0.8004290411025673, 0.10279602963647609, -1.2284506426734476, 0.22189371273337696, 1.3787900806797873, 0.7783415511498849] # drawer: microwave + joint = [-0.23090655681806208, -0.7368697004085705, 0.06469194421473157, -1.5633050220115945, 0.06594510726133981, 1.6454856337730452, 0.7169523042954654] # washingmachine: pad + microwave + self.moveJoint(joint) + + def readPose(self): + if not self.in_impedance_control: + pose = np.array(self.robot.read_once().O_T_EE).reshape(4, 4).T # 
EE2robot, gripper pose + else: + pose = np.array(self.impedance_motion.get_robotstate().O_T_EE).reshape(4, 4).T + return pose + + def movePose(self, pose): + # gripper pose + # tf.euler_from_matrix(pose, axes='rzyx') + # R.from_euler('ZYX', [-1.560670, -0.745688, 1.922058]).as_matrix(), rpy->matrix + if not self.in_impedance_control: + tr = pose[:3, 3] + rot = tf.euler_from_matrix(pose, axes='rzyx') + motion = LinearMotion(Affine(tr[0], tr[1], tr[2], rot[0], rot[1], rot[2])) + self.robot.move(motion) + else: + tr = pose[:3, 3] + rot = tf.euler_from_matrix(pose, axes='rzyx') + self.impedance_motion.target = Affine(tr[0], tr[1], tr[2], rot[0], rot[1], rot[2]) + + def start_impedance_control(self, tr_stiffness=1000.0, rot_stiffness=20.0): + print("you need rebuild frankx to support this") + self.impedance_motion = ImpedanceMotion(tr_stiffness, rot_stiffness) + self.robot_thread = self.robot.move_async(self.impedance_motion) + time.sleep(0.5) + self.in_impedance_control = True + + def end_impedance_control(self): + self.impedance_motion.finish() + self.robot_thread.join() + self.impedance_motion = None + self.robot_thread = None + self.in_impedance_control = False + + def readWrench(self): + if not self.in_impedance_control: + wrench = np.array(self.robot.read_once().O_F_ext_hat_K) # in base + else: + wrench = np.array(self.impedance_motion.get_robotstate().O_F_ext_hat_K) + return wrench + + +if __name__ == "__main__": + robot = Panda() + # test gripper + robot.gripper_close() + robot.gripper_open() + is_graspped = robot.gripper_close() + is_graspped = is_graspped and robot.is_grasping() + print("is_graspped:", is_graspped) + robot.move_gripper(0.05) + gripper_width = robot.read_gripper() + print("gripper width:", gripper_width) + # test arm + robot.homing() + joint = robot.readJoint() + print("current joint:", joint) + EE2robot = robot.readPose() + print("current pose:", EE2robot) + target_pose = EE2robot.copy() + target_pose[:3, 3] += np.array([0.05, 0.05, 0.05]) + robot.movePose(target_pose) + joint = robot.readJoint() + print("current joint:", joint) + EE2robot = robot.readPose() + print("current pose:", EE2robot) + robot.start_impedance_control() + for i in range(10): + current_pose = robot.readPose() + print(i, current_pose) + target_pose = current_pose.copy() + target_pose[:3, 3] += np.array([0., 0.02, 0.]) + robot.movePose(target_pose) + time.sleep(0.3) + EE2robot = robot.readPose() + print("current pose:", EE2robot) + robot.end_impedance_control() + robot.homing() + # pose = robot.readPose() + # np.save("pose_00.npy", pose) + # joints = robot.readJoint() + # np.save("joints_00.npy", joints) diff --git a/envs/robot.py b/envs/robot.py new file mode 100644 index 0000000..48bbe7d --- /dev/null +++ b/envs/robot.py @@ -0,0 +1,354 @@ +""" +Modified from https://github.com/warshallrho/VAT-Mart/blob/main/code/robots/panda_robot.py + Franka Panda Robot Arm + support panda.urdf, panda_gripper.urdf +""" + +from __future__ import division +import numpy as np +from PIL import Image +from sapien.core import Pose + +from .env import Env + + +def rot2so3(rotation): + assert rotation.shape == (3, 3) + if np.isclose(rotation.trace(), 3): + return np.zeros(3), 1 + if np.isclose(rotation.trace(), -1): + raise RuntimeError + theta = np.arccos((rotation.trace() - 1) / 2) + omega = 1 / 2 / np.sin(theta) * np.array( + [rotation[2, 1] - rotation[1, 2], rotation[0, 2] - rotation[2, 0], rotation[1, 0] - rotation[0, 1]]).T + return omega, theta + +def skew(vec): + return np.array([[0, -vec[2], vec[1]], + [vec[2], 
0, -vec[0]], + [-vec[1], vec[0], 0]]) + +def pose2exp_coordinate(pose): + """ + Compute the exponential coordinate corresponding to the given SE(3) matrix + Note: unit twist is not a unit vector + + Args: + pose: (4, 4) transformation matrix + + Returns: + Unit twist: (6, ) vector represent the unit twist + Theta: scalar represent the quantity of exponential coordinate + """ + omega, theta = rot2so3(pose[:3, :3]) + ss = skew(omega) + inv_left_jacobian = np.eye(3, dtype=np.float) / theta - 0.5 * ss + ( + 1.0 / theta - 0.5 / np.tan(theta / 2)) * ss @ ss + v = inv_left_jacobian @ pose[:3, 3] + return np.concatenate([omega, v]), theta + +def adjoint_matrix(pose): + adjoint = np.zeros([6, 6]) + adjoint[:3, :3] = pose[:3, :3] + adjoint[3:6, 3:6] = pose[:3, :3] + adjoint[3:6, 0:3] = skew(pose[:3, 3]) @ pose[:3, :3] + return adjoint + + +class Robot(object): + def __init__(self, env:Env, urdf, material, open_gripper=False, scale=1.0): + self.env = env + self.timestep = env.scene.get_timestep() + + # load robot + loader = env.scene.create_urdf_loader() + loader.fix_root_link = True + loader.scale = scale + self.robot = loader.load(urdf, {"material": material}) + self.robot.name = "robot" + self.max_gripper_width = 0.08 + self.tcp2ee_length = 0.11 + self.scale = scale + + # hand (EE), two grippers, the rest arm joints (if any) + self.end_effector_index, self.end_effector = \ + [(i, l) for i, l in enumerate(self.robot.get_links()) if l.name == 'panda_hand'][0] + self.root2ee = self.end_effector.get_pose().to_transformation_matrix() @ self.robot.get_root_pose().inv().to_transformation_matrix() + self.hand_actor_id = self.end_effector.get_id() + self.gripper_joints = [joint for joint in self.robot.get_joints() if + joint.get_name().startswith("panda_finger_joint")] + self.gripper_actor_ids = [joint.get_child_link().get_id() for joint in self.gripper_joints] + self.arm_joints = [joint for joint in self.robot.get_joints() if + joint.get_dof() > 0 and not joint.get_name().startswith("panda_finger")] + self.g2g = np.array([[0., 0., 1.], [0., 1., 0.], [-1., 0., 0.]]) + + # set drive joint property + for joint in self.arm_joints: + joint.set_drive_property(1000, 400) + for joint in self.gripper_joints: + joint.set_drive_property(200, 60) + + # open/close the gripper at start + if open_gripper: + joint_angles = [] + for j in self.robot.get_joints(): + if j.get_dof() == 1: + if j.get_name().startswith("panda_finger_joint"): + joint_angles.append(self.max_gripper_width / 2.0 * scale) + else: + joint_angles.append(0) + self.robot.set_qpos(joint_angles) + + def load_gripper(self, urdf, material, open_gripper=False, scale=1.0): + raise NotImplementedError + self.timestep = self.env.scene.get_timestep() + + # load robot + loader = self.env.scene.create_urdf_loader() + loader.fix_root_link = True + loader.scale = scale + self.robot = loader.load(urdf, {"material": material}) + self.robot.name = "robot" + self.max_gripper_width = 0.08 + self.tcp2ee_length = 0.11 + self.scale = scale + + # hand (EE), two grippers, the rest arm joints (if any) + self.end_effector_index, self.end_effector = \ + [(i, l) for i, l in enumerate(self.robot.get_links()) if l.name == 'panda_hand'][0] + self.root2ee = self.end_effector.get_pose().to_transformation_matrix() @ self.robot.get_root_pose().inv().to_transformation_matrix() + self.hand_actor_id = self.end_effector.get_id() + self.gripper_joints = [joint for joint in self.robot.get_joints() if + joint.get_name().startswith("panda_finger_joint")] + self.gripper_actor_ids = 
[joint.get_child_link().get_id() for joint in self.gripper_joints]
+        self.arm_joints = [joint for joint in self.robot.get_joints() if
+                           joint.get_dof() > 0 and not joint.get_name().startswith("panda_finger")]
+
+        # set drive joint property
+        for joint in self.arm_joints:
+            joint.set_drive_property(1000, 400)
+        for joint in self.gripper_joints:
+            joint.set_drive_property(200, 60)
+
+        # open/close the gripper at start
+        if open_gripper:
+            joint_angles = []
+            for j in self.robot.get_joints():
+                if j.get_dof() == 1:
+                    if j.get_name().startswith("panda_finger_joint"):
+                        joint_angles.append(self.max_gripper_width / 2.0 * scale)
+                    else:
+                        joint_angles.append(0)
+            self.robot.set_qpos(joint_angles)
+
+    def compute_joint_velocity_from_twist(self, twist: np.ndarray) -> np.ndarray:
+        """
+        This function is a kinematic-level calculation that does not consider dynamics.
+        Pay attention to the frame of the twist: is it a spatial twist or a body twist?
+
+        The Jacobian is provided for you, so there is no need to compute the velocity kinematics.
+        ee_jacobian is the geometric Jacobian accounting only for the arm joints, not the gripper.
+        The Jacobian in SAPIEN is defined as the derivative of the spatial twist with respect to the joint velocity.
+
+        Args:
+            twist: (6,) vector representing the twist
+
+        Returns:
+            (7,) vector of velocities for the arm joints (not including the gripper)
+
+        """
+        assert twist.size == 6
+        # The Jacobian defined in SAPIEN uses the twist (v, \omega), which differs from the definition in the slides,
+        # so we perform the matrix block operation below
+        dense_jacobian = self.robot.compute_spatial_twist_jacobian()  # (num_link * 6, dof())
+        ee_jacobian = np.zeros([6, self.robot.dof - 2])
+        ee_jacobian[:3, :] = dense_jacobian[self.end_effector_index * 6 - 3: self.end_effector_index * 6, :self.robot.dof - 2]
+        ee_jacobian[3:6, :] = dense_jacobian[(self.end_effector_index - 1) * 6: self.end_effector_index * 6 - 3, :self.robot.dof - 2]
+
+        inverse_jacobian = np.linalg.pinv(ee_jacobian, rcond=1e-2)
+        return inverse_jacobian @ twist
+
+    def internal_controller(self, qvel: np.ndarray) -> None:
+        """Control the robot dynamically to execute the given twist for one time step
+
+        This method will try to execute the joint velocity using the internal dynamics function in SAPIEN.
+
+        Note that this function is only used for one time step, so you may need to call it multiple times in your code.
+        Also, this controller is not perfect: it will still have some small movement even after you have finished using
+        it.
Thus try to wait for some steps using self.wait_n_steps(n) like in the hw2.py after you call it multiple + time to allow it to reach the target position + + Args: + qvel: (7,) vector to represent the joint velocity + + """ + assert qvel.size == len(self.arm_joints) + target_qpos = qvel * self.timestep + self.robot.get_drive_target()[:-2] + for i, joint in enumerate(self.arm_joints): + joint.set_drive_velocity_target(qvel[i]) + joint.set_drive_target(target_qpos[i]) + passive_force = self.robot.compute_passive_force() + self.robot.set_qf(passive_force) + + def calculate_twist(self, time_to_target, target_ee_pose): + relative_transform = self.end_effector.get_pose().inv().to_transformation_matrix() @ target_ee_pose + unit_twist, theta = pose2exp_coordinate(relative_transform) + velocity = theta / time_to_target + body_twist = unit_twist * velocity + current_ee_pose = self.end_effector.get_pose().to_transformation_matrix() + return adjoint_matrix(current_ee_pose) @ body_twist + + def set_pose(self, target_ee_pose:np.ndarray, gripper_depth:float): + # target_ee_pose: (4, 4) transformation of robot tcp in world frame, as grasp pose + root_pose = np.identity(4) + root_pose[:3, :3] = target_ee_pose[:3, :3] @ self.g2g + root_pose[:3, 3] = target_ee_pose[:3, 3] + root_pose[:3, 3] -= (self.tcp2ee_length * self.scale - gripper_depth) * root_pose[:3, 2] + root_pose = np.linalg.inv(self.root2ee) @ root_pose + self.robot.set_root_pose(Pose().from_transformation_matrix(root_pose)) + + def get_pose(self, gripper_depth:float) -> np.ndarray: + target_ee_pose = np.identity(4) + # root_pose = self.robot.get_root_pose().to_transformation_matrix() + root_pose = self.end_effector.get_pose().to_transformation_matrix() + target_ee_pose[:3, 3] = root_pose[:3, 3] + (self.tcp2ee_length * self.scale - gripper_depth) * root_pose[:3, 2] + target_ee_pose[:3, :3] = root_pose[:3, :3] @ np.linalg.inv(self.g2g) + return target_ee_pose + + def move_to_target_pose(self, target_ee_pose: np.ndarray, gripper_depth:float, num_steps: int, visu=None, vis_gif=False, vis_gif_interval=200, cam=None) -> None: + """ + Move the robot hand dynamically to a given target pose + Args: + target_ee_pose: (4, 4) transformation of robot tcp in world frame, as grasp pose + num_steps: how much steps to reach to target pose, + each step correspond to self.scene.get_timestep() seconds + in physical simulation + """ + if visu: + waypoints = [] + if vis_gif: + imgs = [] + + executed_time = num_steps * self.timestep + + target_ee_root_pose = np.identity(4) + target_ee_root_pose[:3, :3] = target_ee_pose[:3, :3] @ self.g2g + target_ee_root_pose[:3, 3] = target_ee_pose[:3, 3] + target_ee_root_pose[:3, 3] -= (self.tcp2ee_length * self.scale - gripper_depth) * target_ee_root_pose[:3, 2] + + spatial_twist = self.calculate_twist(executed_time, target_ee_root_pose) + for i in range(num_steps): + if i % 100 == 0: + spatial_twist = self.calculate_twist((num_steps - i) * self.timestep, target_ee_root_pose) + qvel = self.compute_joint_velocity_from_twist(spatial_twist) + self.internal_controller(qvel) + self.env.step() + self.env.render() + if visu and i % 200 == 0: + waypoints.append(self.robot.get_qpos().tolist()) + if vis_gif and ((i + 1) % vis_gif_interval == 0): + rgb_pose, _ = cam.get_observation() + fimg = (rgb_pose*255).astype(np.uint8) + fimg = Image.fromarray(fimg) + imgs.append(fimg) + if vis_gif and (i == 0): + rgb_pose, _ = cam.get_observation() + fimg = (rgb_pose*255).astype(np.uint8) + fimg = Image.fromarray(fimg) + for idx in range(5): + 
imgs.append(fimg) + + if visu and not vis_gif: + return waypoints + if vis_gif and not visu: + return imgs + if visu and vis_gif: + return imgs, waypoints + + def move_to_target_qvel(self, qvel) -> None: + + """ + Move the robot hand dynamically to a given target pose + Args: + target_ee_pose: (4, 4) transformation of robot hand in robot base frame (ee2base) + num_steps: how much steps to reach to target pose, + each step correspond to self.scene.get_timestep() seconds + in physical simulation + """ + assert qvel.size == len(self.arm_joints) + for idx_step in range(100): + target_qpos = qvel * self.timestep + self.robot.get_drive_target()[:-2] + for i, joint in enumerate(self.arm_joints): + joint.set_drive_velocity_target(qvel[i]) + joint.set_drive_target(target_qpos[i]) + passive_force = self.robot.compute_passive_force() + self.robot.set_qf(passive_force) + self.env.step() + self.env.render() + return + + + def close_gripper(self): + for joint in self.gripper_joints: + joint.set_drive_target(0.0) + + def open_gripper(self): + for joint in self.gripper_joints: + joint.set_drive_target(self.max_gripper_width / 2.0) + + def set_gripper(self, width:float): + joint_angles = [] + for j in self.robot.get_joints(): + if j.get_dof() == 1: + if j.get_name().startswith("panda_finger_joint"): + joint_angles.append(max(min(width, self.max_gripper_width) / 2.0, 0.0)) + else: + joint_angles.append(0) + self.robot.set_qpos(joint_angles) + for joint in self.gripper_joints: + joint.set_drive_target(max(min(width, self.max_gripper_width) / 2.0, 0.0)) + + def get_gripper(self) -> float: + joint_angles = self.robot.get_qpos() + width = 0.0 + j_idx = 0 + for j in self.robot.get_joints(): + if j.get_dof() == 1: + if j.get_name().startswith("panda_finger_joint"): + width += joint_angles[j_idx] + else: + pass + j_idx += 1 + return width + + def clear_velocity_command(self): + for joint in self.arm_joints: + joint.set_drive_velocity_target(0) + + def wait_n_steps(self, n: int, visu=None, vis_gif=False, vis_gif_interval=200, cam=None): + imgs = [] + if visu: + waypoints = [] + self.clear_velocity_command() + for i in range(n): + passive_force = self.robot.compute_passive_force() + self.robot.set_qf(passive_force) + self.env.step() + self.env.render() + if visu and i % 200 == 0: + waypoints.append(self.robot.get_qpos().tolist()) + if vis_gif and ((i + 1) % vis_gif_interval == 0): + rgb_pose, _ = cam.get_observation() + fimg = (rgb_pose*255).astype(np.uint8) + fimg = Image.fromarray(fimg) + imgs.append(fimg) + # + self.robot.set_qf([0] * self.robot.dof) + if visu and vis_gif: + return imgs, waypoints + + if visu: + return waypoints + if vis_gif: + return imgs + diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..50e8b05 --- /dev/null +++ b/eval.py @@ -0,0 +1,691 @@ +import configargparse +import json +from omegaconf import OmegaConf +import random +import os +import logging +import numpy as np +import tqdm +import time +import datetime +import torch +import imageio +from PIL import Image +import transformations as tf +import open3d as o3d + +from envs.env import Env +from envs.camera import Camera +from envs.robot import Robot +from utilities.env_utils import setup_seed +from utilities.data_utils import transform_pc, transform_dir, read_joints_from_urdf_file, pc_noise +from utilities.metrics_utils import calc_pose_error, invaffordance_metrics, invaffordances2affordance, calc_translation_error, calc_direction_error +from utilities.constants import seed, max_grasp_width + + +def config_parse() -> 
configargparse.Namespace: + parser = configargparse.ArgumentParser() + + # environment config + parser.add_argument('--camera_config_path', type=str, default='./configs/data/camera_config.json', help='the path to camera config') + parser.add_argument('--num_config_per_object', type=int, default=200, help='the number of configs per object') + # data config + parser.add_argument('--scale', type=float, default=0, help='the scale of the object') + parser.add_argument('--data_path', type=str, default='/data2/junbo/where2act_modified_sapien_dataset/7167/mobility_vhacd.urdf', help='the path to the data') + parser.add_argument('--cat', type=str, default='Microwave', help='the category of the object') + parser.add_argument('--object_config_path', type=str, default='./configs/data/object_config.json', help='the path to object config') + # robot config + parser.add_argument('--robot_urdf_path', type=str, default='/data2/junbo/franka_panda/panda_gripper.urdf', help='the path to robot urdf') + parser.add_argument('--robot_scale', type=float, default=1, help='the scale of the robot') + # gt config + parser.add_argument('--gt_path', type=str, default='/data2/junbo/where2act_modified_sapien_dataset/7167/joint_abs_pose.json', help='the path to gt') + # model config + parser.add_argument('--roartnet', action='store_true', help='whether call roartnet') + parser.add_argument('--roartnet_config_path', type=str, default='./configs/eval_config.yaml', help='the path to roartnet config') + # grasp config + parser.add_argument('--graspnet', action='store_true', help='whether call graspnet') + parser.add_argument('--grasp', action='store_true', help='whether grasp') + parser.add_argument('--gsnet_weight_path', type=str, default='./weights/checkpoint_detection.tar', help='the path to graspnet weight') + parser.add_argument('--show_all_grasps', action='store_true', help='whether show all grasps detected by graspnet, note the visuals will harm the speed') + # task config + parser.add_argument('--selected_part', type=int, default=0, help='the selected part of the object') + # parser.add_argument('--manipulation', type=float, default=-20, help='the manipulation task, positive as push, negative as pull, revolute in degree, prismatic in cm') + parser.add_argument('--task', type=str, default='pull', choices=['pull', 'push', 'none'], help='the task') + parser.add_argument('--task_low', type=float, default=0.1, help='low bound of task') + parser.add_argument('--task_high', type=float, default=0.7, help='high bound of task') + parser.add_argument('--success_threshold', type=float, default=0.15, help='success threshold for ratio of manipulated movement') + # others + parser.add_argument('--gui', action='store_true', help='whether show gui') + parser.add_argument('--video', action='store_true', help='whether save video') + parser.add_argument('--output_path', type=str, default='./outputs/manipulation', help='the path to output') + parser.add_argument('--abbr', type=str, default='7167', help='the abbr of the object') + parser.add_argument('--seed', type=int, default=seed, help='the random seed') + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = config_parse() + setup_seed(args.seed) + # TODO: hardcode here to add noise to point cloud + output_name = 'noise_' + args.cat + '_' + args.abbr + '_' + args.task + if args.roartnet: + output_name += '_roartnet' + output_name += '_' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + output_path = os.path.join(args.output_path, output_name) + 
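+    # Illustrative invocation of this script; the paths are placeholders and every flag is defined in
+    # config_parse above (scripts/eval_roartnet.sh from the README likely wraps a call of this form):
+    #   python eval.py --cat Microwave --abbr 7167 \
+    #       --data_path /path/to/7167/mobility_vhacd.urdf --gt_path /path/to/7167/joint_abs_pose.json \
+    #       --roartnet --roartnet_config_path ./configs/eval_config.yaml \
+    #       --graspnet --grasp --gsnet_weight_path ./weights/checkpoint_detection.tar \
+    #       --task pull --selected_part 0 --num_config_per_object 200 --video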
os.makedirs(output_path, exist_ok=True) + logger = logging.getLogger("manipulation") + logger.setLevel(level=logging.DEBUG) + handler = logging.FileHandler(filename=os.path.join(output_path, 'log.txt')) + logger.addHandler(handler) + + if args.roartnet: + # load roartnet + start_time = time.time() + from models.roartnet import create_shot_encoder, create_encoder + from inference import inference_fn as roartnet_inference_fn + roartnet_cfg = OmegaConf.load(args.roartnet_config_path) + trained_path = roartnet_cfg.trained.path[args.cat] + trained_cfg = OmegaConf.load(f"{trained_path}/.hydra/config.yaml") + roartnet_cfg = OmegaConf.merge(trained_cfg, roartnet_cfg) + joint_num = roartnet_cfg.dataset.joint_num + resolution = roartnet_cfg.dataset.resolution + receptive_field = roartnet_cfg.dataset.receptive_field + has_rgb = roartnet_cfg.dataset.rgb + denoise = roartnet_cfg.dataset.denoise + normalize = roartnet_cfg.dataset.normalize + sample_points_num = roartnet_cfg.dataset.sample_points_num + sample_tuples_num = roartnet_cfg.algorithm.sampling.sample_tuples_num + tuple_more_num = roartnet_cfg.algorithm.sampling.tuple_more_num + shot_hidden_dims = roartnet_cfg.algorithm.shot_encoder.hidden_dims + shot_feature_dim = roartnet_cfg.algorithm.shot_encoder.feature_dim + shot_bn = roartnet_cfg.algorithm.shot_encoder.bn + shot_ln = roartnet_cfg.algorithm.shot_encoder.ln + shot_droput = roartnet_cfg.algorithm.shot_encoder.dropout + shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim, + shot_bn, shot_ln, shot_droput) + shot_encoder.load_state_dict(torch.load(f'{trained_path}/weights/shot_encoder_latest.pth', map_location=torch.device('cuda'))) + shot_encoder = shot_encoder.cuda() + shot_encoder.eval() + overall_hidden_dims = roartnet_cfg.algorithm.encoder.hidden_dims + rot_bin_num = roartnet_cfg.algorithm.voting.rot_bin_num + overall_bn = roartnet_cfg.algorithm.encoder.bn + overall_ln = roartnet_cfg.algorithm.encoder.ln + overall_dropout = roartnet_cfg.algorithm.encoder.dropout + encoder = create_encoder(tuple_more_num, shot_feature_dim, has_rgb, overall_hidden_dims, rot_bin_num, joint_num, + overall_bn, overall_ln, overall_dropout) + encoder.load_state_dict(torch.load(f'{trained_path}/weights/encoder_latest.pth', map_location=torch.device('cuda'))) + encoder = encoder.cuda() + encoder.eval() + voting_num = roartnet_cfg.algorithm.voting.voting_num + angle_tol = roartnet_cfg.algorithm.voting.angle_tol + translation2pc = roartnet_cfg.algorithm.voting.translation2pc + multi_candidate = roartnet_cfg.algorithm.voting.multi_candidate + candidate_threshold = roartnet_cfg.algorithm.voting.candidate_threshold + rotation_multi_neighbor = roartnet_cfg.algorithm.voting.rotation_multi_neighbor + neighbor_threshold = roartnet_cfg.algorithm.voting.neighbor_threshold + rotation_cluster = roartnet_cfg.algorithm.voting.rotation_cluster + bmm_size = roartnet_cfg.algorithm.voting.bmm_size + end_time = time.time() + logger.info(f"===> loaded roartnet {end_time - start_time}") + + if args.graspnet: + # load graspnet + start_time = time.time() + from munch import DefaultMunch + from gsnet import AnyGrasp + grasp_detector_cfg = { + 'checkpoint_path': args.gsnet_weight_path, + 'max_gripper_width': max_grasp_width * args.robot_scale, + 'gripper_height': 0.03, + 'top_down_grasp': False, + 'add_vdistance': True, + 'debug': True + } + grasp_detector_cfg = DefaultMunch.fromDict(grasp_detector_cfg) + grasp_detector = AnyGrasp(grasp_detector_cfg) + grasp_detector.load_net() + end_time = time.time() + logger.info(f"===> 
loaded graspnet {end_time - start_time}") + + # initialize environment + start_time = time.time() + env = Env(show_gui=args.gui) + if args.camera_config_path == 'none': + camera_config = None + cam = Camera(env, random_position=True, restrict_dir=True) + else: + with open(args.camera_config_path, "r") as fp: + camera_config = json.load(fp) + camera_near = camera_config['intrinsics']['near'] + camera_far = camera_config['intrinsics']['far'] + camera_width = camera_config['intrinsics']['width'] + camera_height = camera_config['intrinsics']['height'] + camera_fovx = np.random.uniform(camera_config['intrinsics']['fovx'][0], camera_config['intrinsics']['fovx'][1]) + camera_fovy = np.random.uniform(camera_config['intrinsics']['fovy'][0], camera_config['intrinsics']['fovy'][1]) + camera_dist = np.random.uniform(camera_config['extrinsics']['dist'][0], camera_config['extrinsics']['dist'][1]) + camera_phi = np.random.uniform(camera_config['extrinsics']['phi'][0], camera_config['extrinsics']['phi'][1]) + camera_theta = np.random.uniform(camera_config['extrinsics']['theta'][0], camera_config['extrinsics']['theta'][1]) + cam = Camera(env, near=camera_near, far=camera_far, image_size=[camera_width, camera_height], fov=[camera_fovx, camera_fovy], + dist=camera_dist, phi=camera_phi / 180 * np.pi, theta=camera_theta / 180 * np.pi) + if args.gui: + env.set_controller_camera_pose(cam.pos[0], cam.pos[1], cam.pos[2], np.pi + cam.theta, -cam.phi) + object_material = env.get_material(4, 4, 0.01) + if args.object_config_path == 'none': + object_config = { + "scale_min": 1, + "scale_max": 1, + } + else: + with open(args.object_config_path, "r") as fp: + object_config = json.load(fp) + if args.gt_path != "none": + with open(args.gt_path, "r") as fp: + gt_config = json.load(fp) + max_object_size = 0 + for joint_name in gt_config.keys(): + if joint_name == 'aabb_min' or joint_name == 'aabb_max': + continue + max_object_size = max(max_object_size, np.max(np.array(gt_config[joint_name]["bbox_max"]) - np.array(gt_config[joint_name]["bbox_min"]))) + if args.cat in object_config: + object_scale = object_config[args.cat]['size'] / max_object_size + joint_num = object_config[args.cat]['joint_num'] + assert joint_num == len(gt_config.keys()) - 2 + else: + object_scale = 1 + joint_num = len(gt_config.keys()) - 2 + if args.scale != 0: + object_scale = args.scale + else: + object_scale = args.scale + end_time = time.time() + logger.info(f"===> initialized environment {end_time - start_time}") + + success_configs, fail_configs = [], [] + for config_id in tqdm.trange(args.num_config_per_object): + # load object + this_config_scale = np.random.uniform(object_config["scale_min"], object_config["scale_max"]) * object_scale + video = [] + start_time = time.time() + still = False + try_times = 0 + while not still and try_times < 5: + goal_qpos, joint_abs_angles = env.load_object(args.data_path, object_material, state='random-middle-middle', target_part_id=-1, target_part_idx=args.selected_part, scale=this_config_scale) + env.render() + + # check still and reach goal qpos + start_time = time.time() + still_timesteps = 0 + wait_timesteps = 0 + cur_qpos = env.get_object_qpos() + # while still_timesteps < 500 and wait_timesteps < 3000: + while still_timesteps < 500 and wait_timesteps < 5000: + env.step() + env.render() + cur_new_qpos = env.get_object_qpos() + invalid_contact = False + for c in env.scene.get_contacts(): + for p in c.points: + if abs(p.impulse @ p.impulse) > 1e-4: + invalid_contact = True + break + if invalid_contact: + 
break + # if np.max(np.abs(cur_new_qpos - cur_qpos)) < 1e-6 and (not invalid_contact): + if np.max(np.abs(cur_new_qpos - cur_qpos)) < 1e-6 and np.max(np.abs(cur_new_qpos - goal_qpos)) < 0.02 and (not invalid_contact): + still_timesteps += 1 + else: + still_timesteps = 0 + cur_qpos = cur_new_qpos + wait_timesteps += 1 + still = still_timesteps >= 500 + if not still: + env.scene.remove_articulation(env.object) + try_times += 1 + if not still: + logger.info(f"{config_id} failed to load object") + continue + end_time = time.time() + logger.info(f"===> {config_id} loaded object {end_time - start_time} {this_config_scale} {try_times}") + + # set camera + if camera_config is None: + cam.change_pose(random_position=True, restrict_dir=True) + else: + camera_fovx = np.random.uniform(camera_config['intrinsics']['fovx'][0], camera_config['intrinsics']['fovx'][1]) + camera_fovy = np.random.uniform(camera_config['intrinsics']['fovy'][0], camera_config['intrinsics']['fovy'][1]) + cam.change_fov([camera_fovx, camera_fovy]) + camera_dist = np.random.uniform(camera_config['extrinsics']['dist'][0], camera_config['extrinsics']['dist'][1]) + camera_phi = np.random.uniform(camera_config['extrinsics']['phi'][0], camera_config['extrinsics']['phi'][1]) + camera_theta = np.random.uniform(camera_config['extrinsics']['theta'][0], camera_config['extrinsics']['theta'][1]) + cam.change_pose(dist=camera_dist, phi=camera_phi / 180 * np.pi, theta=camera_theta / 180 * np.pi) + if args.gui: + env.set_controller_camera_pose(cam.pos[0], cam.pos[1], cam.pos[2], np.pi + cam.theta, -cam.phi) + env.step() + env.render() + if args.video: + frame, _ = cam.get_observation() + frame = (frame * 255).astype(np.uint8) + frame = Image.fromarray(frame) + video.append(frame) + + # get observation + rgb, depth = cam.get_observation() + cam_XYZA_id1, cam_XYZA_id2, cam_XYZA_pts = cam.compute_camera_XYZA(depth) + mask = cam.camera.get_segmentation() # used only for filtering out bad cases + object_mask = mask[cam_XYZA_id1, cam_XYZA_id2] + extrinsic = cam.get_metadata()['mat44'] + R = extrinsic[:3, :3] + T = extrinsic[:3, 3] + pcd_cam = cam_XYZA_pts.copy() + pcd_world = (R @ cam_XYZA_pts.T).T + T + pcd_color = rgb[cam_XYZA_id1, cam_XYZA_id2] + c2c = np.array([[0, -1, 0, 0], [0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 0, 1]]) + pcd_grasp = transform_pc(pcd_cam, c2c) + # TODO: hardcode here to add noise to point cloud + pcd_cam = pc_noise(pcd_cam, 0.2, 0.01, 0.002, 0.5) + + # obtain gt + start_time = time.time() + # env.add_frame_visual() + # env.render() + if args.gt_path != "none": + # set task + movable_link_ids = env.movable_link_ids + object_joint_types = env.movable_link_joint_types + movable_link_joint_names = env.movable_link_joint_names + all_link_ids = env.all_link_ids + all_link_names = env.all_link_names + joint_real_states = env.get_object_qpos() + joint_lower_states = env.joint_angles_lower + joint_upper_states = env.joint_angles_upper + env.set_target_object_part_actor_id(movable_link_ids[args.selected_part]) + if object_joint_types[args.selected_part] == 0: + target_joint_real_state = joint_real_states[args.selected_part] / np.pi * 180 + target_joint_lower_state = joint_lower_states[args.selected_part] / np.pi * 180 + target_joint_upper_state = joint_upper_states[args.selected_part] / np.pi * 180 + elif object_joint_types[args.selected_part] == 1: + target_joint_real_state = joint_real_states[args.selected_part] * 100 + target_joint_lower_state = joint_lower_states[args.selected_part] * 100 + target_joint_upper_state = 
joint_upper_states[args.selected_part] * 100 + else: + raise ValueError(f"invalid joint type {object_joint_types[args.selected_part]}") + if args.task == 'push': + task_high = min(args.task_high * (target_joint_upper_state - target_joint_lower_state), target_joint_real_state - target_joint_lower_state - 15) + elif args.task == 'pull': + task_high = min(args.task_high * (target_joint_upper_state - target_joint_lower_state), target_joint_upper_state - target_joint_real_state - 5) + elif args.task == 'none': + task_high = target_joint_upper_state - target_joint_lower_state + else: + raise ValueError(f"invalid task {args.task}") + task_low = args.task_low * (target_joint_upper_state - target_joint_lower_state) + if task_high <= 0 or task_low >= task_high: + logger.info(f"{config_id} cannot set task {target_joint_upper_state - target_joint_lower_state} {target_joint_real_state - target_joint_lower_state}") + env.scene.remove_articulation(env.object) + continue + if args.task == 'push': + task_state = random.random() * (task_high - task_low) + task_low + elif args.task == 'pull': + task_state = random.random() * (task_high - task_low) + task_low + task_state *= -1 + elif args.task == 'none': + task_state = 0 + else: + raise ValueError(f"invalid task {args.task}") + assert len(movable_link_ids) == len(object_joint_types) and len(object_joint_types) == len(movable_link_joint_names) + assert len(movable_link_ids) == len(joint_real_states) and len(joint_real_states) == len(joint_lower_states) and len(joint_lower_states) == len(joint_upper_states) + logger.info(f"{config_id} task {task_state}") + + joint_bases, joint_directions, joint_types, joint_res, joint_states, affordable_positions = [], [], [], [], [], [] + for idx, link_id in enumerate(movable_link_ids): + joint_pose_meta = gt_config[movable_link_joint_names[idx]] + joint_base = np.asarray(joint_pose_meta['base_position'], order='F') * this_config_scale + joint_bases.append(joint_base) + # env.add_point_visual(joint_base, color=[0, 1, 0], name='joint_base_{}'.format(idx)) + joint_direction = np.asarray(joint_pose_meta['direction'], order='F') + joint_directions.append(joint_direction) + # env.add_line_visual(joint_base, joint_base + this_config_scale * joint_direction, color=[0, 1, 0], name='joint_direction_{}'.format(idx)) + joint_type = joint_pose_meta['joint_type'] + assert joint_type == object_joint_types[idx] + joint_types.append(joint_type) + if joint_type == 0: + joint_re = joint_pose_meta["joint_re"] + elif joint_type == 1: + joint_re = 0 + else: + raise ValueError(f"invalid joint type {joint_pose_meta['joint_type']}") + joint_res.append(joint_re) + joint_state = joint_real_states[idx] - joint_lower_states[idx] + joint_states.append(joint_state) + affordable_position = np.asarray(joint_pose_meta['affordable_position'], order='F') * this_config_scale + if joint_type == 0: + transformation_matrix = tf.rotation_matrix(angle=joint_state * joint_re, direction=joint_direction, point=joint_base) + elif joint_type == 1: + transformation_matrix = tf.translation_matrix(joint_state * joint_direction) + else: + raise ValueError(f"invalid joint type {joint_pose_meta['joint_type']}") + affordable_position = transform_pc(affordable_position[None, :], transformation_matrix)[0] + affordable_positions.append(affordable_position) + # env.add_point_visual(affordable_position, color=[0, 0, 1], name='affordable_position_{}'.format(idx)) + if joint_type == 0: + logger.info(f"added joint {movable_link_joint_names[idx]} revolute {'pull_counterclockwise' if 
joint_re == 1 else 'pull_clockwise'} {joint_state / np.pi * 180} {(joint_upper_states[idx] - joint_lower_states[idx]) / np.pi * 180}") + elif joint_type == 1: + logger.info(f"added joint {movable_link_joint_names[idx]} prismatic {joint_state * 100} {(joint_upper_states[idx] - joint_lower_states[idx]) * 100}") + else: + raise ValueError(f"invalid joint type {joint_pose_meta['joint_type']}") + selected_joint_base = joint_bases[args.selected_part] + selected_joint_direction = joint_directions[args.selected_part] + selected_joint_type = joint_types[args.selected_part] + selected_joint_re = joint_res[args.selected_part] + selected_joint_state = joint_states[args.selected_part] + selected_affordable_position = affordable_positions[args.selected_part] + + # merge affiliated parts into parents + joints_dict = read_joints_from_urdf_file(args.data_path) + link_graph = {} + for joint_name in joints_dict: + link_graph[joints_dict[joint_name]['child']] = joints_dict[joint_name]['parent'] + mask_ins = np.unique(object_mask) + fixed_id = -1 + for mask_in in mask_ins: + if mask_in not in movable_link_ids: + link_name = all_link_names[all_link_ids.index(mask_in)] + parent_link = link_graph[link_name] + parent_link_id = all_link_ids[all_link_names.index(parent_link)] + if parent_link_id in mask_ins: + if parent_link_id in movable_link_ids: + object_mask[object_mask == mask_in] = parent_link_id + else: + # TODO: may be error + if fixed_id == -1: + fixed_id = parent_link_id + object_mask[object_mask == mask_in] = fixed_id + else: + # TODO: may be error + if fixed_id == -1: + fixed_id = parent_link_id + object_mask[object_mask == mask_in] = fixed_id + mask_ins = np.unique(object_mask) + fixed_num = 0 + for mask_in in mask_ins: + if mask_in not in movable_link_ids: + fixed_num += 1 + assert fixed_num <= 1 + assert mask_ins.shape[0] <= len(movable_link_ids) + 1 + + # rearrange mask + instance_mask = np.zeros_like(object_mask) + real_id = 1 + selected_joint_idxs = [] + for mask_id, movable_id in zip(movable_link_ids, range(len(movable_link_ids))): + if mask_id in mask_ins: + instance_mask[object_mask == mask_id] = real_id + real_id += 1 + selected_joint_idxs.append(movable_id) + + # check regular + if len(selected_joint_idxs) != joint_num or np.min(instance_mask) != 0 or np.max(instance_mask) != joint_num: + logger.info(f"{config_id} irregular") + env.scene.remove_articulation(env.object) + continue + + # check seen + if (instance_mask == (args.selected_part + 1)).sum() < 0.15 * instance_mask.shape[0]: + logger.info(f"{config_id} unseen") + env.scene.remove_articulation(env.object) + continue + + # check suitable + affordable_dist = np.linalg.norm(pcd_world - selected_affordable_position, axis=-1).min() + if affordable_dist > 0.03: + logger.info(f"{config_id} unsuitable") + env.scene.remove_articulation(env.object) + continue + + for _ in range(100): + env.step() + env.render() + if args.video: + frame, _ = cam.get_observation() + frame = (frame * 255).astype(np.uint8) + frame = Image.fromarray(frame) + video.append(frame) + end_time = time.time() + logger.info(f"===> added gt {end_time - start_time}") + + # prediction + if args.roartnet: + start_time = time.time() + pred_joint_bases, pred_joint_directions, pred_affordable_positions = roartnet_inference_fn(pcd_cam, pcd_color if has_rgb else None, shot_encoder, encoder, + denoise, normalize, resolution, receptive_field, sample_points_num, sample_tuples_num, tuple_more_num, + voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, 
candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, bmm_size, joint_num, device=0) + pred_selected_joint_base = pred_joint_bases[args.selected_part] + pred_selected_joint_direction = pred_joint_directions[args.selected_part] + pred_selected_affordable_position = pred_affordable_positions[args.selected_part] + pred_selected_joint_base = transform_pc(pred_selected_joint_base[None, :], extrinsic)[0] + pred_selected_joint_direction = transform_dir(pred_selected_joint_direction[None, :], extrinsic)[0] + pred_selected_affordable_position = transform_pc(pred_selected_affordable_position[None, :], extrinsic)[0] + joint_translation_errors = calc_translation_error(pred_selected_joint_base, selected_joint_base, pred_selected_joint_direction, selected_joint_direction) + joint_direction_error = calc_direction_error(pred_selected_joint_direction, selected_joint_direction) + affordance_error, _, _, _, _ = calc_translation_error(pred_selected_affordable_position, selected_affordable_position, None, None) + selected_joint_base = pred_selected_joint_base + selected_joint_direction = pred_selected_joint_direction + selected_affordable_position = pred_selected_affordable_position + end_time = time.time() + logger.info(f"===> {config_id} roartnet predicted {end_time - start_time} {joint_translation_errors} {joint_direction_error} {affordance_error}") + + # obtain grasp + if args.graspnet: + start_time = time.time() + # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=None, lims=[ + # np.floor(np.min(pcd_grasp[:, 0])) - 0.1, np.ceil(np.max(pcd_grasp[:, 0])) + 0.1, + # np.floor(np.min(pcd_grasp[:, 1])) - 0.1, np.ceil(np.max(pcd_grasp[:, 1])) + 0.1, + # np.floor(np.min(pcd_grasp[:, 2])) - 0.1, np.ceil(np.max(pcd_grasp[:, 2])) + 0.1]) + # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=[-float('inf'), float('inf'), -float('inf'), float('inf'), -float('inf'), float('inf')]) + # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, apply_object_mask=True, dense_grasp=False, collision_detection=True) + try: + gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='fast') + except: + gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='slow') + if gg_grasp is None: + gg_grasp = [] + else: + if len(gg_grasp) != 2: + gg_grasp = [] + else: + gg_grasp, pcd_o3d = gg_grasp + gg_grasp = gg_grasp.nms().sort_by_score() + # if args.gui: + # grippers_o3d = gg_grasp.to_open3d_geometry_list() + # frame_o3d = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + # o3d.visualization.draw_geometries([*grippers_o3d, pcd_o3d, frame_o3d]) + end_time = time.time() + logger.info(f"===> {config_id} obtained grasp {end_time - start_time} {len(gg_grasp)}") + if len(gg_grasp) == 0: + logger.info(f"{config_id} no grasp detected") + env.scene.remove_articulation(env.object) + continue + + # add visuals + start_time = time.time() + grasp_scores, grasp_widths, grasp_depths, grasp_translations, grasp_rotations, grasp_invaffordances = [], [], [], [], [], [] + for g_idx, g_grasp in enumerate(gg_grasp): + grasp_score = g_grasp.score + grasp_scores.append(grasp_score) + grasp_width = g_grasp.width + grasp_widths.append(grasp_width) + grasp_depth = g_grasp.depth + 
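+            # The lines below compose each grasp's 4x4 pose from its rotation and translation
+            # and map it from the detector's camera frame into the world frame through
+            # `extrinsic` and the inverse of `c2c`.  A minimal, self-contained sketch of that
+            # pattern, using hypothetical identity matrices rather than the script's real ones:
+            #
+            #     >>> import numpy as np
+            #     >>> R, t = np.eye(3), np.array([0.1, 0.0, 0.5])      # grasp rotation / translation (camera frame)
+            #     >>> T_grasp = np.eye(4); T_grasp[:3, :3] = R; T_grasp[:3, 3] = t
+            #     >>> extrinsic_hyp, c2c_hyp = np.eye(4), np.eye(4)    # assumed camera-to-world and camera-to-camera transforms
+            #     >>> T_world = extrinsic_hyp @ np.linalg.inv(c2c_hyp) @ T_grasp
+            #     >>> T_world[:3, 3]                                   # grasp position expressed in world coordinates
+            #     array([0.1, 0. , 0.5])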
grasp_depths.append(grasp_depth) + grasp_translation = g_grasp.translation + grasp_rotation = g_grasp.rotation_matrix + grasp_transformation = np.identity(4) + grasp_transformation[:3, :3] = grasp_rotation + grasp_transformation[:3, 3] = grasp_translation + grasp_transformation = extrinsic @ np.linalg.inv(c2c) @ grasp_transformation + grasp_translation = grasp_transformation[:3, 3] + grasp_translations.append(grasp_translation) + grasp_rotation = grasp_transformation[:3, :3] + grasp_rotations.append(grasp_rotation) + grasp_invaffordance = invaffordance_metrics(grasp_translation, grasp_rotation, grasp_score, selected_affordable_position, + selected_joint_base, selected_joint_direction, selected_joint_type) + grasp_invaffordances.append(grasp_invaffordance) + grasp_affordances = invaffordances2affordance(grasp_invaffordances) + if args.show_all_grasps: + for g_idx, g_grasp in enumerate(gg_grasp): + # env.add_grasp_visual(grasp_widths[g_idx], grasp_depths[g_idx], grasp_translations[g_idx], grasp_rotations[g_idx], affordance=(grasp_affordances[g_idx] - min(grasp_affordances)) / (max(grasp_affordances) - min(grasp_affordances)), name='grasp_{}'.format(g_idx)) + print("added grasp", grasp_affordances[g_idx]) + selected_grasp_idx = np.argmax(grasp_affordances) + selected_grasp_score = grasp_scores[selected_grasp_idx] + selected_grasp_width = grasp_widths[selected_grasp_idx] + selected_grasp_width = max(min(selected_grasp_width * 1.5, max_grasp_width * args.robot_scale), 0.0) + selected_grasp_depth = grasp_depths[selected_grasp_idx] + selected_grasp_translation = grasp_translations[selected_grasp_idx] + selected_grasp_rotation = grasp_rotations[selected_grasp_idx] + selected_grasp_affordance = grasp_affordances[selected_grasp_idx] + # env.add_grasp_visual(selected_grasp_width, selected_grasp_depth, selected_grasp_translation, selected_grasp_rotation, affordance=selected_grasp_affordance, name='selected_grasp') + selected_grasp_pose = np.identity(4) + selected_grasp_pose[:3, :3] = selected_grasp_rotation + selected_grasp_pose[:3, 3] = selected_grasp_translation + for _ in range(100): + env.step() + env.render() + if args.video: + frame, _ = cam.get_observation() + frame = (frame * 255).astype(np.uint8) + frame = Image.fromarray(frame) + video.append(frame) + end_time = time.time() + logger.info(f"===> {config_id} added visuals {end_time - start_time}") + + # load robot + if args.grasp: + start_time = time.time() + robot_material = env.get_material(4, 4, 0.01) + robot_move_steps_per_unit = 1000 + robot_short_wait_steps = 10 + robot_long_wait_steps = 1000 + robot = Robot(env, args.robot_urdf_path, robot_material, open_gripper=False, scale=args.robot_scale) + env.end_checking_contact(robot.hand_actor_id, robot.gripper_actor_ids, False) + selected_grasp_pre_pose = selected_grasp_pose.copy() + selected_grasp_pre_pose[:3, 3] -= 0.1 * selected_grasp_pre_pose[:3, 0] + robot.set_gripper(selected_grasp_width) + robot.set_pose(selected_grasp_pre_pose, selected_grasp_depth) + # env.add_point_visual(selected_grasp_pre_pose[:3, 3], color=[1, 0, 0], radius=0.01, name='selected_grasp_pre_translation') + for _ in range(100): + env.step() + env.render() + if args.video: + frame, _ = cam.get_observation() + frame = (frame * 255).astype(np.uint8) + frame = Image.fromarray(frame) + video.append(frame) + read_width = robot.get_gripper() + read_pose = robot.get_pose(selected_grasp_depth) + pose_error = calc_pose_error(selected_grasp_pre_pose, read_pose) + end_time = time.time() + logger.info(f"===> {config_id} loaded 
robot {end_time - start_time} {read_width - selected_grasp_width} {pose_error}") + + # grasp + start_time = time.time() + frames = robot.move_to_target_pose(selected_grasp_pose, selected_grasp_depth, robot_move_steps_per_unit * 10, vis_gif=args.video, cam=cam) + if args.video: + video.extend(frames) + frames = robot.wait_n_steps(robot_short_wait_steps, vis_gif=args.video, cam=cam) + if args.video: + video.extend(frames) + robot.close_gripper() + frames = robot.wait_n_steps(robot_long_wait_steps, vis_gif=args.video, cam=cam) + if args.video: + video.extend(frames) + current_width = robot.get_gripper() + success_grasp = env.check_contact_right() + read_pose = robot.get_pose(selected_grasp_depth) + pose_error = calc_pose_error(selected_grasp_pose, read_pose) + # env.add_point_visual(selected_grasp_pose[:3, 3], color=[0, 0, 1], radius=0.01, name='selected_grasp_translation') + # env.render() + # if args.video: + # frame, _ = cam.get_observation() + # frame = (frame * 255).astype(np.uint8) + # frame = Image.fromarray(frame) + # video.append(frame) + end_time = time.time() + logger.info(f"===> {config_id} grasped {end_time - start_time} {pose_error} {success_grasp}") + + # manipulation + if task_state != 0: + plan_trajectory, real_trajectory, semi_trajectory = [], [], [] + plan_current_pose = robot.get_pose(selected_grasp_depth) + joint_state_initial = env.get_target_part_state() + if selected_joint_type == 0: + joint_state_initial = joint_state_initial / np.pi * 180 + elif selected_joint_type == 1: + joint_state_initial = joint_state_initial * 100 + else: + raise ValueError(f"invalid joint type {selected_joint_type}") + joint_state_task = joint_state_initial - task_state * (1 - args.success_threshold) + start_time = time.time() + for step in tqdm.trange(int(np.ceil(np.abs(task_state / 2.0) * 1.5))): + current_pose = robot.get_pose(selected_grasp_depth) + if selected_joint_type == 0: + rotation_angle = -2.0 * np.sign(task_state) * selected_joint_re / 180.0 * np.pi + delta_pose = tf.rotation_matrix(angle=rotation_angle, direction=selected_joint_direction, point=selected_joint_base) + elif selected_joint_type == 1: + translation_distance = -2.0 * np.sign(task_state) / 100.0 + delta_pose = tf.translation_matrix(selected_joint_direction * translation_distance) + else: + raise ValueError(f"invalid joint type {selected_joint_type}") + # next_pose = delta_pose @ plan_current_pose + next_pose = delta_pose @ current_pose + frames = robot.move_to_target_pose(next_pose, selected_grasp_depth, robot_move_steps_per_unit * 2, vis_gif=args.video, cam=cam) + if args.video: + video.extend(frames) + # robot.wait_n_steps(robot_short_wait_steps) + read_pose = robot.get_pose(selected_grasp_depth) + pose_error = calc_pose_error(next_pose, read_pose) + current_joint_state = env.get_target_part_state() + if selected_joint_type == 0: + current_joint_state = current_joint_state / np.pi * 180 + elif selected_joint_type == 1: + current_joint_state = current_joint_state * 100 + else: + raise ValueError(f"invalid joint type {selected_joint_type}") + joint_state_error = current_joint_state - joint_state_task + logger.info(f"{pose_error} {joint_state_error}") + real_trajectory.append(read_pose) + semi_trajectory.append(next_pose) + plan_next_pose = delta_pose @ plan_current_pose + plan_current_pose = plan_next_pose.copy() + plan_trajectory.append(plan_next_pose) + success_manipulation = ((task_state < 0) and (joint_state_error > 0)) or ((task_state > 0) and (joint_state_error < 0)) + if success_manipulation: + break + # for 
step in tqdm.trange(int(np.ceil(np.abs(task_state / 2.0)))): + # if step % 2 != 0: + # continue + # env.add_grasp_visual(current_width, selected_grasp_depth, plan_trajectory[step][:3, 3], plan_trajectory[step][:3, :3], affordance=1, name='grasp_{}'.format(-step)) + # env.add_grasp_visual(current_width, selected_grasp_depth, real_trajectory[step][:3, 3], real_trajectory[step][:3, :3], affordance=0, name='grasp_{}_real'.format(-step)) + # env.add_grasp_visual(current_width, selected_grasp_depth, semi_trajectory[step][:3, 3], semi_trajectory[step][:3, :3], affordance=0.99, name='grasp_{}_semi'.format(-step)) + # env.render() + # if args.video: + # frame, _ = cam.get_observation() + # frame = (frame * 255).astype(np.uint8) + # frame = Image.fromarray(frame) + # video.append(frame) + end_time = time.time() + logger.info(f"===> {config_id} manipulated {end_time - start_time} {success_manipulation}") + if success_manipulation: + success_configs.append(config_id) + else: + fail_configs.append(config_id) + env.scene.remove_articulation(robot.robot) + env.scene.remove_articulation(env.object) + + if args.video: + imageio.mimsave(os.path.join(output_path, f'video_{str(config_id).zfill(2)}_{str(success_manipulation)}.mp4'), video) + + logger.info(f"===> success configs {success_configs} {len(success_configs)}") + logger.info(f"===> fail configs {fail_configs} {len(fail_configs)}") + logger.info(f"===> success rate {len(success_configs) / (len(success_configs) + len(fail_configs))}") diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..3e606d0 --- /dev/null +++ b/inference.py @@ -0,0 +1,331 @@ +from typing import Tuple, List, Optional +from omegaconf import OmegaConf +import os +import tqdm +import itertools +from itertools import combinations +import random +import numpy as np +import torch +import torch.nn as nn +import cupy as cp +# import cudf +# import cuml +from sklearn.cluster import KMeans +# from cuml.cluster import DBSCAN +# from cuml.common.device_selection import using_device_type +import MinkowskiEngine as ME +import open3d as o3d + +from models.roartnet import create_shot_encoder, create_encoder +from models.voting import ppf_kernel, rot_voting_kernel, ppf4d_kernel +from utilities.metrics_utils import calc_translation_error, calc_direction_error +from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting, visualize_confidence_voting +from utilities.data_utils import pc_normalize, farthest_point_sample, fibonacci_sphere +from utilities.env_utils import setup_seed +from utilities.constants import seed, light_blue_color, red_color, dark_red_color, dark_green_color, yellow_color +from src_shot.build import shot + + +def voting_translation(pc:np.ndarray, tr_offsets:np.ndarray, point_idxs:np.ndarray, confs:np.ndarray, + resolution:float, voting_num:int, device:int, + translation2pc:bool, multi_candidate:bool, candidate_threshold:float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + # pc: (N, 3), tr_offsets: (N_t, 2), point_idxs: (N_t, 2), confs: (N_t,) + block_size = (tr_offsets.shape[0] + 512 - 1) // 512 + pc_min = np.min(pc, 0) + pc_max = np.max(pc, 0) + corner_min = pc_min - (pc_max - pc_min) + corner_max = pc_max + (pc_max - pc_min) + corners = np.stack([corner_min, corner_max]) + grid_res = ((corners[1] - corners[0]) / resolution).astype(np.int32) + 1 + + with cp.cuda.Device(device): + grid_obj = cp.asarray(np.zeros(grid_res, dtype=np.float32)) + + ppf_kernel( + (block_size, 1, 1), + (512, 1, 1), + ( + 
cp.ascontiguousarray(cp.asarray(pc).astype(cp.float32)), + cp.ascontiguousarray(cp.asarray(tr_offsets).astype(cp.float32)), + cp.ascontiguousarray(cp.asarray(confs).astype(cp.float32)), + cp.ascontiguousarray(cp.asarray(point_idxs).astype(cp.int32)), + grid_obj, + cp.ascontiguousarray(cp.asarray(corners[0]).astype(cp.float32)), + cp.float32(resolution), + cp.int32(tr_offsets.shape[0]), + cp.int32(voting_num), + cp.int32(grid_obj.shape[0]), + cp.int32(grid_obj.shape[1]), + cp.int32(grid_obj.shape[2]) + ) + ) + + if not multi_candidate: + cand = cp.array(cp.unravel_index(cp.array([cp.argmax(grid_obj, axis=None)]), grid_obj.shape)).T[::-1] + cand_world = cp.asarray(corners[0]) + cand * resolution + else: + indices = cp.indices(grid_obj.shape) + indices_list = cp.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape)) + votes_list = grid_obj.reshape(-1) + grid_pc = cp.asarray(corners[0]) + indices_list * resolution + normalized_votes_list = votes_list / cp.max(votes_list) + candidates = grid_pc[normalized_votes_list >= candidate_threshold] + candidate_weights = normalized_votes_list[normalized_votes_list >= candidate_threshold] + candidate_weights = candidate_weights / cp.sum(candidate_weights) + cand_world = cp.sum(candidates * candidate_weights[:, None], axis=0)[None, :] + + if translation2pc: + pc_cp = cp.asarray(pc) + best_idx = cp.linalg.norm(pc_cp - cand_world, axis=-1).argmin() + translation = pc_cp[best_idx] + else: + translation = cand_world[0] + + return (translation.get(), grid_obj.get(), corners) + +def voting_rotation(pc:np.ndarray, rot_offsets:np.ndarray, point_idxs:np.ndarray, confs:np.ndarray, + rot_candidate_num:int, angle_tol:float, voting_num:int, bmm_size:int, device:int, + multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool, kmeans:Optional[KMeans], + rotation_multi_neighbor:bool, neighbor_threshold:float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + # pc: (N, 3), rot_offsets: (N_t,), point_idxs: (N_t, 2), confs: (N_t,) + block_size = (rot_offsets.shape[0] + 512 - 1) // 512 + sphere_pts = np.array(fibonacci_sphere(rot_candidate_num)) + expanded_confs = confs[:, None].repeat(voting_num, axis=-1).reshape(-1, 1) + + with cp.cuda.Device(device): + candidates = cp.zeros((rot_offsets.shape[0], voting_num, 3), cp.float32) + + rot_voting_kernel( + (block_size, 1, 1), + (512, 1, 1), + ( + cp.ascontiguousarray(cp.asarray(pc).astype(cp.float32)), + cp.ascontiguousarray(cp.asarray(rot_offsets).astype(cp.float32)), + candidates, + cp.ascontiguousarray(cp.asarray(point_idxs).astype(cp.int32)), + cp.int32(rot_offsets.shape[0]), + cp.int32(voting_num) + ) + ) + + candidates = candidates.get().reshape(-1, 3) + + with torch.no_grad(): + candidates = torch.from_numpy(candidates).cuda(device) + expanded_confs = torch.from_numpy(expanded_confs).cuda(device) + sph_cp = torch.tensor(sphere_pts.T, dtype=torch.float32).cuda(device) + counts = torch.zeros((sphere_pts.shape[0],), dtype=torch.float32).cuda(device) # (rot_candidate_num,) + + for i in range((candidates.shape[0] - 1) // bmm_size + 1): + cos = candidates[i * bmm_size:(i + 1) * bmm_size].mm(sph_cp) # (bmm_size, rot_candidate_num) + if not rotation_multi_neighbor: + voting = (cos > np.cos(2 * angle_tol / 180 * np.pi)).float() # (bmm_size, rot_candidate_num) + else: + # voting_indices = torch.topk(cos, neighbors_num, dim=-1)[1] + # voting_mask = torch.zeros_like(cos) + # voting_mask.scatter_(1, voting_indices, 1) + voting_mask = (cos > np.cos(neighbor_threshold / 180 * np.pi)).float() + voting = cos * 
voting_mask # (bmm_size, rot_candidate_num) + counts += torch.sum(voting * expanded_confs[i * bmm_size:(i + 1) * bmm_size], dim=0) + + if not multi_candidate: + direction = np.array(sphere_pts[np.argmax(counts.cpu().numpy())]) + else: + counts_list = counts.cpu().numpy() + normalized_counts_list = counts_list / np.max(counts_list) + candidates = sphere_pts[normalized_counts_list >= candidate_threshold] + candidate_weights = normalized_counts_list[normalized_counts_list >= candidate_threshold] + candidate_weights = candidate_weights / np.sum(candidate_weights) + if not rotation_cluster: + direction = np.sum(candidates * candidate_weights[:, None], axis=0) + direction /= np.linalg.norm(direction) + else: + if candidates.shape[0] == 1: + direction = candidates[0] + else: + kmeans.fit(candidates) + candidate_center1 = kmeans.cluster_centers_[0] + candidate_center2 = kmeans.cluster_centers_[1] + cluster_cos_theta = np.dot(candidate_center1, candidate_center2) + cluster_cos_theta = np.clip(cluster_cos_theta, -1., 1.) + cluster_theta = np.arccos(cluster_cos_theta) + if cluster_theta > np.pi/2: + candidate_clusters = kmeans.labels_ + clusters_num = np.bincount(candidate_clusters) + if clusters_num[0] == clusters_num[1]: + candidate_weights1 = candidate_weights[candidate_clusters == 0] + candidate_weights2 = candidate_weights[candidate_clusters == 1] + if np.sum(candidate_weights1) >= np.sum(candidate_weights2): + candidates = candidates[candidate_clusters == 0] + candidate_weights = candidate_weights[candidate_clusters == 0] + candidate_weights = candidate_weights / np.sum(candidate_weights) + direction = np.sum(candidates * candidate_weights[:, None], axis=0) + direction /= np.linalg.norm(direction) + else: + candidates = candidates[candidate_clusters == 1] + candidate_weights = candidate_weights[candidate_clusters == 1] + candidate_weights = candidate_weights / np.sum(candidate_weights) + direction = np.sum(candidates * candidate_weights[:, None], axis=0) + direction /= np.linalg.norm(direction) + else: + max_cluster = np.bincount(candidate_clusters).argmax() + candidates = candidates[candidate_clusters == max_cluster] + candidate_weights = candidate_weights[candidate_clusters == max_cluster] + candidate_weights = candidate_weights / np.sum(candidate_weights) + direction = np.sum(candidates * candidate_weights[:, None], axis=0) + direction /= np.linalg.norm(direction) + else: + direction = np.sum(candidates * candidate_weights[:, None], axis=0) + direction /= np.linalg.norm(direction) + + return (direction, sphere_pts, counts.cpu().numpy()) + + +def inference_fn(pc:np.ndarray, pc_color:Optional[np.ndarray], shot_encoder:nn.Module, encoder:nn.Module, + denoise:bool, normalize:str, resolution:float, receptive_field:int, sample_points_num:int, sample_tuples_num:int, tuple_more_num:int, + voting_num:int, rot_bin_num:int, angle_tol:float, + translation2pc:bool, multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool, + rotation_multi_neighbor:bool, neighbor_threshold:float, bmm_size:int, joint_num:int, device:int) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + if not hasattr(inference_fn, 'permutations'): + inference_fn.permutations = list(itertools.permutations(range(sample_points_num), 2)) + inference_fn.sample_points_num = sample_points_num + else: + if inference_fn.sample_points_num != sample_points_num: + inference_fn.permutations = list(itertools.permutations(range(sample_points_num), 2)) + inference_fn.sample_points_num = sample_points_num + else: + pass + if 
rotation_cluster: + # kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto') + kmeans = KMeans(n_clusters=2, init='k-means++', n_init=1) + else: + kmeans = None + rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi)) + has_rgb = pc_color is not None + + # preprocess point cloud + if denoise: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc) + _, index = pcd.remove_statistical_outlier(nb_neighbors=100, std_ratio=1.5) + pc = pc[index] + if has_rgb: + pc_color = pc_color[index] + + pc, center, scale = pc_normalize(pc, normalize) + + indices = ME.utils.sparse_quantize(np.ascontiguousarray(pc), return_index=True, quantization_size=resolution)[1] + pc = np.ascontiguousarray(pc[indices].astype(np.float32)) + if has_rgb: + pc_color = pc_color[indices] + + pc_normal = shot.estimate_normal(pc, resolution * receptive_field).reshape(-1, 3).astype(np.float32) + pc_normal[~np.isfinite(pc_normal)] = 0 + + pc_shot = shot.compute(pc, resolution * receptive_field, resolution * receptive_field).reshape(-1, 352).astype(np.float32) + pc_shot[~np.isfinite(pc_shot)] = 0 + + pc, indices = farthest_point_sample(pc, sample_points_num) + pc_normal = pc_normal[indices] + pc_shot = pc_shot[indices] + if has_rgb: + pc_color = pc_color[indices] + + point_idxs = random.sample(inference_fn.permutations, sample_tuples_num) + point_idxs = np.array(point_idxs, dtype=np.int64) + point_idxs_more = np.random.randint(0, sample_points_num, size=(sample_tuples_num, tuple_more_num), dtype=np.int64) + point_idxs_all = np.concatenate([point_idxs, point_idxs_more], axis=-1) + + pcs = torch.from_numpy(pc)[None, ...].cuda(device) + pc_normals = torch.from_numpy(pc_normal)[None, ...].cuda(device) + pc_shots = torch.from_numpy(pc_shot)[None, ...].cuda(device) + if has_rgb: + pc_colors = torch.from_numpy(pc_color)[None, ...].cuda(device) + point_idxs_alls = torch.from_numpy(point_idxs_all)[None, ...].cuda(device) + + # inference + with torch.no_grad(): + shot_feat = shot_encoder(pc_shots) # (1, N, N_s) + + shot_inputs = torch.cat([ + torch.gather(shot_feat, 1, + point_idxs_alls[:, :, i:i+1].expand( + (1, sample_tuples_num, shot_feat.shape[-1]))) + for i in range(point_idxs_alls.shape[-1])], dim=-1) # (1, N_t, N_s * (2 + N_m)) + normal_inputs = torch.cat([torch.max( + torch.sum(torch.gather(pc_normals, 1, + point_idxs_alls[:, :, i:i+1].expand( + (1, sample_tuples_num, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_alls[:, :, j:j+1].expand( + (1, sample_tuples_num, pc_normals.shape[-1]))), + dim=-1, keepdim=True), + torch.sum(-torch.gather(pc_normals, 1, + point_idxs_alls[:, :, i:i+1].expand( + (1, sample_tuples_num, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_alls[:, :, j:j+1].expand( + (1, sample_tuples_num, pc_normals.shape[-1]))), + dim=-1, keepdim=True)) + for (i, j) in combinations(np.arange(point_idxs_alls.shape[-1]), 2)], dim=-1) # (1, N_t, (2+N_m \choose 2)) + coord_inputs = torch.cat([ + torch.gather(pcs, 1, + point_idxs_alls[:, :, i:i+1].expand( + (1, sample_tuples_num, pcs.shape[-1]))) - + torch.gather(pcs, 1, + point_idxs_alls[:, :, j:j+1].expand( + (1, sample_tuples_num, pcs.shape[-1]))) + for (i, j) in combinations(np.arange(point_idxs_alls.shape[-1]), 2)], dim=-1) # (1, N_t, 3 * (2+N_m \choose 2)) + if has_rgb: + rgb_inputs = torch.cat([ + torch.gather(pc_colors, 1, + point_idxs_alls[:, :, i:i+1].expand( + (1, sample_tuples_num, pc_colors.shape[-1]))) + for i in range(point_idxs_alls.shape[-1])], dim=-1) # (1, N_t, 3 * (2 + N_m)) + 
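+        # The gathers above build the per-tuple network inputs: for every sampled point tuple,
+        # the features of its member points are looked up with torch.gather and concatenated.
+        # A minimal, self-contained illustration of that pattern, with toy shapes rather than
+        # the network's real dimensions:
+        #
+        #     >>> import torch
+        #     >>> B, N, N_t, C = 1, 6, 4, 5                       # batch, points, tuples, feature dim
+        #     >>> feats = torch.randn(B, N, C)                    # per-point features (e.g. SHOT embeddings)
+        #     >>> tuples = torch.randint(0, N, (B, N_t, 3))       # indices of the 3 points in each tuple
+        #     >>> per_point = [torch.gather(feats, 1, tuples[:, :, k:k+1].expand(B, N_t, C))
+        #     ...              for k in range(tuples.shape[-1])]
+        #     >>> torch.cat(per_point, dim=-1).shape              # one feature row per tuple
+        #     torch.Size([1, 4, 15])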
inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1) + else: + inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1) + preds = encoder(inputs) # (1, N_t, (2 + N_r + 2 + 1) * J) + pred = preds.cpu().numpy().astype(np.float32)[0] # (N_t, (2 + N_r + 2 + 1) * J) + pred_tensor = torch.from_numpy(pred) + + pred_translations, pred_rotations, pred_affordances = [], [], [] + for j in range(joint_num): + # conf selection + pred_conf = torch.sigmoid(pred_tensor[:, -1*joint_num+j]) # (N_t,) + not_selected_indices = pred_conf < 0.5 + pred_conf[not_selected_indices] = 0 + # pred_conf[pred_conf > 0] = 1 + pred_conf = pred_conf.numpy() + + # translation voting + pred_tr = pred[:, 2*j:2*(j+1)] # (N_t, 2) + pred_translation, grid_obj, corners = voting_translation(pc, pred_tr, point_idxs, pred_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_translations.append(pred_translation) + + # rotation voting + pred_rot = pred_tensor[:, (2*joint_num+rot_bin_num*j):(2*joint_num+rot_bin_num*(j+1))] # (N_t, rot_bin_num) + pred_rot = torch.softmax(pred_rot, dim=-1) + pred_rot = torch.multinomial(pred_rot, 1).float()[:, 0] # (N_t,) + pred_rot = pred_rot / (rot_bin_num - 1) * np.pi + pred_rot = pred_rot.numpy() + pred_direction, sphere_pts, counts = voting_rotation(pc, pred_rot, point_idxs, pred_conf, + rot_candidate_num, angle_tol, voting_num, bmm_size, device, + multi_candidate, candidate_threshold, rotation_cluster, kmeans, + rotation_multi_neighbor, neighbor_threshold) + pred_rotations.append(pred_direction) + + # affordance voting + pred_afford = pred[:, (2*joint_num+rot_bin_num*joint_num+2*j):(2*joint_num+rot_bin_num*joint_num+2*(j+1))] # (N_t, 2) + pred_affordance, agrid_obj, acorners = voting_translation(pc, pred_afford, point_idxs, pred_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_affordances.append(pred_affordance) + return (pred_translations, pred_rotations, pred_affordances) + + +if __name__ == '__main__': + pass diff --git a/license/.gitkeep b/license/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/helper_math.cuh b/models/helper_math.cuh new file mode 100644 index 0000000..4597d8a --- /dev/null +++ b/models/helper_math.cuh @@ -0,0 +1,1450 @@ +#pragma once +/** +* Copyright 1993-2012 NVIDIA Corporation. All rights reserved. +* +* Please refer to the NVIDIA end user license agreement (EULA) associated +* with this source code for terms and conditions that govern your use of +* this software. Any use, reproduction, disclosure, or distribution of +* this software and related documentation outside the terms of the EULA +* is strictly prohibited. +* +*/ + +/* +* This file implements common mathematical operations on vector types +* (float3, float4 etc.) since these are not provided as standard by CUDA. +* +* The syntax is modeled on the Cg standard library. +* +* This is part of the Helper library includes +* +* Thanks to Linh Hah for additions and fixes. +*/ + +#ifndef HELPER_MATH_H +#define HELPER_MATH_H + +#include "/usr/local/cuda/include/cuda_runtime.h" + +typedef unsigned int uint; +typedef unsigned short ushort; + +#ifndef __CUDACC__ +#include + +//////////////////////////////////////////////////////////////////////////////// +// host implementations of CUDA functions +//////////////////////////////////////////////////////////////////////////////// + +inline float fminf(float a, float b) +{ + return a < b ? 
a : b; +} + +inline float fmaxf(float a, float b) +{ + return a > b ? a : b; +} + +inline int max(int a, int b) +{ + return a > b ? a : b; +} + +inline int min(int a, int b) +{ + return a < b ? a : b; +} + +inline float rsqrtf(float x) +{ + return 1.0f / sqrtf(x); +} +#endif + +//////////////////////////////////////////////////////////////////////////////// +// constructors +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 make_float2(float s) +{ + return make_float2(s, s); +} +inline __host__ __device__ float2 make_float2(float3 a) +{ + return make_float2(a.x, a.y); +} +inline __host__ __device__ float2 make_float2(int2 a) +{ + return make_float2(float(a.x), float(a.y)); +} +inline __host__ __device__ float2 make_float2(uint2 a) +{ + return make_float2(float(a.x), float(a.y)); +} + +inline __host__ __device__ int2 make_int2(int s) +{ + return make_int2(s, s); +} +inline __host__ __device__ int2 make_int2(int3 a) +{ + return make_int2(a.x, a.y); +} +inline __host__ __device__ int2 make_int2(uint2 a) +{ + return make_int2(int(a.x), int(a.y)); +} +inline __host__ __device__ int2 make_int2(float2 a) +{ + return make_int2(int(a.x), int(a.y)); +} + +inline __host__ __device__ uint2 make_uint2(uint s) +{ + return make_uint2(s, s); +} +inline __host__ __device__ uint2 make_uint2(uint3 a) +{ + return make_uint2(a.x, a.y); +} +inline __host__ __device__ uint2 make_uint2(int2 a) +{ + return make_uint2(uint(a.x), uint(a.y)); +} + +inline __host__ __device__ float3 make_float3(float s) +{ + return make_float3(s, s, s); +} +inline __host__ __device__ float3 make_float3(float2 a) +{ + return make_float3(a.x, a.y, 0.0f); +} +inline __host__ __device__ float3 make_float3(float2 a, float s) +{ + return make_float3(a.x, a.y, s); +} +inline __host__ __device__ float3 make_float3(float4 a) +{ + return make_float3(a.x, a.y, a.z); +} +inline __host__ __device__ float3 make_float3(int3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} +inline __host__ __device__ float3 make_float3(uint3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} + +inline __host__ __device__ int3 make_int3(int s) +{ + return make_int3(s, s, s); +} +inline __host__ __device__ int3 make_int3(int2 a) +{ + return make_int3(a.x, a.y, 0); +} +inline __host__ __device__ int3 make_int3(int2 a, int s) +{ + return make_int3(a.x, a.y, s); +} +inline __host__ __device__ int3 make_int3(uint3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} +inline __host__ __device__ int3 make_int3(float3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} + +inline __host__ __device__ uint3 make_uint3(uint s) +{ + return make_uint3(s, s, s); +} +inline __host__ __device__ uint3 make_uint3(uint2 a) +{ + return make_uint3(a.x, a.y, 0); +} +inline __host__ __device__ uint3 make_uint3(uint2 a, uint s) +{ + return make_uint3(a.x, a.y, s); +} +inline __host__ __device__ uint3 make_uint3(uint4 a) +{ + return make_uint3(a.x, a.y, a.z); +} +inline __host__ __device__ uint3 make_uint3(int3 a) +{ + return make_uint3(uint(a.x), uint(a.y), uint(a.z)); +} + +inline __host__ __device__ float4 make_float4(float s) +{ + return make_float4(s, s, s, s); +} +inline __host__ __device__ float4 make_float4(float3 a) +{ + return make_float4(a.x, a.y, a.z, 0.0f); +} +inline __host__ __device__ float4 make_float4(float3 a, float w) +{ + return make_float4(a.x, a.y, a.z, w); +} +inline __host__ __device__ float4 make_float4(int4 a) +{ + return make_float4(float(a.x), 
float(a.y), float(a.z), float(a.w)); +} +inline __host__ __device__ float4 make_float4(uint4 a) +{ + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); +} + +inline __host__ __device__ int4 make_int4(int s) +{ + return make_int4(s, s, s, s); +} +inline __host__ __device__ int4 make_int4(int3 a) +{ + return make_int4(a.x, a.y, a.z, 0); +} +inline __host__ __device__ int4 make_int4(int3 a, int w) +{ + return make_int4(a.x, a.y, a.z, w); +} +inline __host__ __device__ int4 make_int4(uint4 a) +{ + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); +} +inline __host__ __device__ int4 make_int4(float4 a) +{ + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); +} + + +inline __host__ __device__ uint4 make_uint4(uint s) +{ + return make_uint4(s, s, s, s); +} +inline __host__ __device__ uint4 make_uint4(uint3 a) +{ + return make_uint4(a.x, a.y, a.z, 0); +} +inline __host__ __device__ uint4 make_uint4(uint3 a, uint w) +{ + return make_uint4(a.x, a.y, a.z, w); +} +inline __host__ __device__ uint4 make_uint4(int4 a) +{ + return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// negate +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator-(float2 &a) +{ + return make_float2(-a.x, -a.y); +} +inline __host__ __device__ int2 operator-(int2 &a) +{ + return make_int2(-a.x, -a.y); +} +inline __host__ __device__ float3 operator-(float3 &a) +{ + return make_float3(-a.x, -a.y, -a.z); +} +inline __host__ __device__ int3 operator-(int3 &a) +{ + return make_int3(-a.x, -a.y, -a.z); +} +inline __host__ __device__ float4 operator-(float4 &a) +{ + return make_float4(-a.x, -a.y, -a.z, -a.w); +} +inline __host__ __device__ int4 operator-(int4 &a) +{ + return make_int4(-a.x, -a.y, -a.z, -a.w); +} + +//////////////////////////////////////////////////////////////////////////////// +// addition +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator+(float2 a, float2 b) +{ + return make_float2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(float2 &a, float2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ float2 operator+(float2 a, float b) +{ + return make_float2(a.x + b, a.y + b); +} +inline __host__ __device__ float2 operator+(float b, float2 a) +{ + return make_float2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(float2 &a, float b) +{ + a.x += b; + a.y += b; +} + +inline __host__ __device__ int2 operator+(int2 a, int2 b) +{ + return make_int2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(int2 &a, int2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ int2 operator+(int2 a, int b) +{ + return make_int2(a.x + b, a.y + b); +} +inline __host__ __device__ int2 operator+(int b, int2 a) +{ + return make_int2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(int2 &a, int b) +{ + a.x += b; + a.y += b; +} + +inline __host__ __device__ uint2 operator+(uint2 a, uint2 b) +{ + return make_uint2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(uint2 &a, uint2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ uint2 operator+(uint2 a, uint b) +{ + return make_uint2(a.x + b, a.y + b); +} +inline __host__ __device__ uint2 operator+(uint b, uint2 a) +{ + return make_uint2(a.x + b, a.y + b); +} +inline __host__ __device__ 
void operator+=(uint2 &a, uint b) +{ + a.x += b; + a.y += b; +} + + +inline __host__ __device__ float3 operator+(float3 a, float3 b) +{ + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(float3 &a, float3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ float3 operator+(float3 a, float b) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(float3 &a, float b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ int3 operator+(int3 a, int3 b) +{ + return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(int3 &a, int3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ int3 operator+(int3 a, int b) +{ + return make_int3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(int3 &a, int b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ uint3 operator+(uint3 a, uint3 b) +{ + return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(uint3 &a, uint3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ uint3 operator+(uint3 a, uint b) +{ + return make_uint3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(uint3 &a, uint b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ int3 operator+(int b, int3 a) +{ + return make_int3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ uint3 operator+(uint b, uint3 a) +{ + return make_uint3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ float3 operator+(float b, float3 a) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} + +inline __host__ __device__ float4 operator+(float4 a, float4 b) +{ + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(float4 &a, float4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ float4 operator+(float4 a, float b) +{ + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ float4 operator+(float b, float4 a) +{ + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(float4 &a, float b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +inline __host__ __device__ int4 operator+(int4 a, int4 b) +{ + return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(int4 &a, int4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ int4 operator+(int4 a, int b) +{ + return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ int4 operator+(int b, int4 a) +{ + return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(int4 &a, int b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +inline __host__ __device__ uint4 operator+(uint4 a, uint4 b) +{ + return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(uint4 &a, uint4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ uint4 operator+(uint4 a, uint b) +{ + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ uint4 operator+(uint b, uint4 a) +{ + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ 
__device__ void operator+=(uint4 &a, uint b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +//////////////////////////////////////////////////////////////////////////////// +// subtract +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator-(float2 a, float2 b) +{ + return make_float2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(float2 &a, float2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ float2 operator-(float2 a, float b) +{ + return make_float2(a.x - b, a.y - b); +} +inline __host__ __device__ float2 operator-(float b, float2 a) +{ + return make_float2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(float2 &a, float b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ int2 operator-(int2 a, int2 b) +{ + return make_int2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(int2 &a, int2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ int2 operator-(int2 a, int b) +{ + return make_int2(a.x - b, a.y - b); +} +inline __host__ __device__ int2 operator-(int b, int2 a) +{ + return make_int2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(int2 &a, int b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ uint2 operator-(uint2 a, uint2 b) +{ + return make_uint2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(uint2 &a, uint2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ uint2 operator-(uint2 a, uint b) +{ + return make_uint2(a.x - b, a.y - b); +} +inline __host__ __device__ uint2 operator-(uint b, uint2 a) +{ + return make_uint2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(uint2 &a, uint b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ float3 operator-(float3 a, float3 b) +{ + return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(float3 &a, float3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ float3 operator-(float3 a, float b) +{ + return make_float3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ float3 operator-(float b, float3 a) +{ + return make_float3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(float3 &a, float b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ __device__ int3 operator-(int3 a, int3 b) +{ + return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(int3 &a, int3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ int3 operator-(int3 a, int b) +{ + return make_int3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ int3 operator-(int b, int3 a) +{ + return make_int3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(int3 &a, int b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ __device__ uint3 operator-(uint3 a, uint3 b) +{ + return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(uint3 &a, uint3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ uint3 operator-(uint3 a, uint b) +{ + return make_uint3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ uint3 operator-(uint b, uint3 a) +{ + return make_uint3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(uint3 &a, uint b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ 
__device__ float4 operator-(float4 a, float4 b) +{ + return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(float4 &a, float4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ float4 operator-(float4 a, float b) +{ + return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ void operator-=(float4 &a, float b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +inline __host__ __device__ int4 operator-(int4 a, int4 b) +{ + return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(int4 &a, int4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ int4 operator-(int4 a, int b) +{ + return make_int4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ int4 operator-(int b, int4 a) +{ + return make_int4(b - a.x, b - a.y, b - a.z, b - a.w); +} +inline __host__ __device__ void operator-=(int4 &a, int b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +inline __host__ __device__ uint4 operator-(uint4 a, uint4 b) +{ + return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(uint4 &a, uint4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ uint4 operator-(uint4 a, uint b) +{ + return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ uint4 operator-(uint b, uint4 a) +{ + return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w); +} +inline __host__ __device__ void operator-=(uint4 &a, uint b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +//////////////////////////////////////////////////////////////////////////////// +// multiply +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator*(float2 a, float2 b) +{ + return make_float2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(float2 &a, float2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ float2 operator*(float2 a, float b) +{ + return make_float2(a.x * b, a.y * b); +} +inline __host__ __device__ float2 operator*(float b, float2 a) +{ + return make_float2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(float2 &a, float b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ int2 operator*(int2 a, int2 b) +{ + return make_int2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(int2 &a, int2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ int2 operator*(int2 a, int b) +{ + return make_int2(a.x * b, a.y * b); +} +inline __host__ __device__ int2 operator*(int b, int2 a) +{ + return make_int2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(int2 &a, int b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ uint2 operator*(uint2 a, uint2 b) +{ + return make_uint2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(uint2 &a, uint2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ uint2 operator*(uint2 a, uint b) +{ + return make_uint2(a.x * b, a.y * b); +} +inline __host__ __device__ uint2 operator*(uint b, uint2 a) +{ + return make_uint2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(uint2 &a, uint b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ float3 operator*(float3 a, float3 b) +{ + return make_float3(a.x * b.x, a.y * 
b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(float3 &a, float3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ float3 operator*(float3 a, float b) +{ + return make_float3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ float3 operator*(float b, float3 a) +{ + return make_float3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(float3 &a, float b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ int3 operator*(int3 a, int3 b) +{ + return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(int3 &a, int3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ int3 operator*(int3 a, int b) +{ + return make_int3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ int3 operator*(int b, int3 a) +{ + return make_int3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(int3 &a, int b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ uint3 operator*(uint3 a, uint3 b) +{ + return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(uint3 &a, uint3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ uint3 operator*(uint3 a, uint b) +{ + return make_uint3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ uint3 operator*(uint b, uint3 a) +{ + return make_uint3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(uint3 &a, uint b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ float4 operator*(float4 a, float4 b) +{ + return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(float4 &a, float4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ float4 operator*(float4 a, float b) +{ + return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ float4 operator*(float b, float4 a) +{ + return make_float4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(float4 &a, float b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +inline __host__ __device__ int4 operator*(int4 a, int4 b) +{ + return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(int4 &a, int4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ int4 operator*(int4 a, int b) +{ + return make_int4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ int4 operator*(int b, int4 a) +{ + return make_int4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(int4 &a, int b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +inline __host__ __device__ uint4 operator*(uint4 a, uint4 b) +{ + return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(uint4 &a, uint4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ uint4 operator*(uint4 a, uint b) +{ + return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ uint4 operator*(uint b, uint4 a) +{ + return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(uint4 &a, uint b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +//////////////////////////////////////////////////////////////////////////////// 
+// divide +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator/(float2 a, float2 b) +{ + return make_float2(a.x / b.x, a.y / b.y); +} +inline __host__ __device__ void operator/=(float2 &a, float2 b) +{ + a.x /= b.x; + a.y /= b.y; +} +inline __host__ __device__ float2 operator/(float2 a, float b) +{ + return make_float2(a.x / b, a.y / b); +} +inline __host__ __device__ void operator/=(float2 &a, float b) +{ + a.x /= b; + a.y /= b; +} +inline __host__ __device__ float2 operator/(float b, float2 a) +{ + return make_float2(b / a.x, b / a.y); +} + +inline __host__ __device__ float3 operator/(float3 a, float3 b) +{ + return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ void operator/=(float3 &a, float3 b) +{ + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; +} +inline __host__ __device__ float3 operator/(float3 a, float b) +{ + return make_float3(a.x / b, a.y / b, a.z / b); +} +inline __host__ __device__ void operator/=(float3 &a, float b) +{ + a.x /= b; + a.y /= b; + a.z /= b; +} +inline __host__ __device__ float3 operator/(float b, float3 a) +{ + return make_float3(b / a.x, b / a.y, b / a.z); +} + +inline __host__ __device__ float4 operator/(float4 a, float4 b) +{ + return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); +} +inline __host__ __device__ void operator/=(float4 &a, float4 b) +{ + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; + a.w /= b.w; +} +inline __host__ __device__ float4 operator/(float4 a, float b) +{ + return make_float4(a.x / b, a.y / b, a.z / b, a.w / b); +} +inline __host__ __device__ void operator/=(float4 &a, float b) +{ + a.x /= b; + a.y /= b; + a.z /= b; + a.w /= b; +} +inline __host__ __device__ float4 operator/(float b, float4 a) +{ + return make_float4(b / a.x, b / a.y, b / a.z, b / a.w); +} + +//////////////////////////////////////////////////////////////////////////////// +// min +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fminf(float2 a, float2 b) +{ + return make_float2(fminf(a.x, b.x), fminf(a.y, b.y)); +} +inline __host__ __device__ float3 fminf(float3 a, float3 b) +{ + return make_float3(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z)); +} +inline __host__ __device__ float4 fminf(float4 a, float4 b) +{ + return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w)); +} + +inline __host__ __device__ int2 min(int2 a, int2 b) +{ + return make_int2(min(a.x, b.x), min(a.y, b.y)); +} +inline __host__ __device__ int3 min(int3 a, int3 b) +{ + return make_int3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); +} +inline __host__ __device__ int4 min(int4 a, int4 b) +{ + return make_int4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w)); +} + +inline __host__ __device__ uint2 min(uint2 a, uint2 b) +{ + return make_uint2(min(a.x, b.x), min(a.y, b.y)); +} +inline __host__ __device__ uint3 min(uint3 a, uint3 b) +{ + return make_uint3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); +} +inline __host__ __device__ uint4 min(uint4 a, uint4 b) +{ + return make_uint4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// max +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fmaxf(float2 a, float2 b) +{ + return make_float2(fmaxf(a.x, b.x), fmaxf(a.y, b.y)); +} +inline __host__ __device__ float3 fmaxf(float3 a, 
float3 b) +{ + return make_float3(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z)); +} +inline __host__ __device__ float4 fmaxf(float4 a, float4 b) +{ + return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w)); +} + +inline __host__ __device__ int2 max(int2 a, int2 b) +{ + return make_int2(max(a.x, b.x), max(a.y, b.y)); +} +inline __host__ __device__ int3 max(int3 a, int3 b) +{ + return make_int3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); +} +inline __host__ __device__ int4 max(int4 a, int4 b) +{ + return make_int4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w)); +} + +inline __host__ __device__ uint2 max(uint2 a, uint2 b) +{ + return make_uint2(max(a.x, b.x), max(a.y, b.y)); +} +inline __host__ __device__ uint3 max(uint3 a, uint3 b) +{ + return make_uint3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); +} +inline __host__ __device__ uint4 max(uint4 a, uint4 b) +{ + return make_uint4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// lerp +// - linear interpolation between a and b, based on value t in [0, 1] range +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float lerp(float a, float b, float t) +{ + return a + t*(b - a); +} +inline __device__ __host__ float2 lerp(float2 a, float2 b, float t) +{ + return a + t*(b - a); +} +inline __device__ __host__ float3 lerp(float3 a, float3 b, float t) +{ + return a + t*(b - a); +} +inline __device__ __host__ float4 lerp(float4 a, float4 b, float t) +{ + return a + t*(b - a); +} + +//////////////////////////////////////////////////////////////////////////////// +// clamp +// - clamp the value v to be in the range [a, b] +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float clamp(float f, float a, float b) +{ + return fmaxf(a, fminf(f, b)); +} +inline __device__ __host__ int clamp(int f, int a, int b) +{ + return max(a, min(f, b)); +} +inline __device__ __host__ uint clamp(uint f, uint a, uint b) +{ + return max(a, min(f, b)); +} + +inline __device__ __host__ float2 clamp(float2 v, float a, float b) +{ + return make_float2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b) +{ + return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ float3 clamp(float3 v, float a, float b) +{ + return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) +{ + return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ float4 clamp(float4 v, float a, float b) +{ + return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b) +{ + return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +inline __device__ __host__ int2 clamp(int2 v, int a, int b) +{ + return make_int2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b) +{ + return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ int3 clamp(int3 v, int a, int b) +{ + return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ 
int3 clamp(int3 v, int3 a, int3 b) +{ + return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ int4 clamp(int4 v, int a, int b) +{ + return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b) +{ + return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b) +{ + return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b) +{ + return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b) +{ + return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b) +{ + return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b) +{ + return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b) +{ + return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// dot product +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float dot(float2 a, float2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ float dot(float3 a, float3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ float dot(float4 a, float4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +inline __host__ __device__ int dot(int2 a, int2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ int dot(int3 a, int3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ int dot(int4 a, int4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +inline __host__ __device__ uint dot(uint2 a, uint2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ uint dot(uint3 a, uint3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ uint dot(uint4 a, uint4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +//////////////////////////////////////////////////////////////////////////////// +// length +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float length(float2 v) +{ + return sqrtf(dot(v, v)); +} +inline __host__ __device__ float length(float3 v) +{ + return sqrtf(dot(v, v)); +} +inline __host__ __device__ float length(float4 v) +{ + return sqrtf(dot(v, v)); +} + +//////////////////////////////////////////////////////////////////////////////// +// normalize +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 normalize(float2 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} +inline __host__ __device__ float3 normalize(float3 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} +inline __host__ __device__ float4 normalize(float4 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + 
+//////////////////////////////////////////////////////////////////////////////// +// floor +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 floorf(float2 v) +{ + return make_float2(floorf(v.x), floorf(v.y)); +} +inline __host__ __device__ float3 floorf(float3 v) +{ + return make_float3(floorf(v.x), floorf(v.y), floorf(v.z)); +} +inline __host__ __device__ float4 floorf(float4 v) +{ + return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// frac - returns the fractional portion of a scalar or each vector component +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float fracf(float v) +{ + return v - floorf(v); +} +inline __host__ __device__ float2 fracf(float2 v) +{ + return make_float2(fracf(v.x), fracf(v.y)); +} +inline __host__ __device__ float3 fracf(float3 v) +{ + return make_float3(fracf(v.x), fracf(v.y), fracf(v.z)); +} +inline __host__ __device__ float4 fracf(float4 v) +{ + return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// fmod +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fmodf(float2 a, float2 b) +{ + return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y)); +} +inline __host__ __device__ float3 fmodf(float3 a, float3 b) +{ + return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z)); +} +inline __host__ __device__ float4 fmodf(float4 a, float4 b) +{ + return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// absolute value +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fabs(float2 v) +{ + return make_float2(fabs(v.x), fabs(v.y)); +} +inline __host__ __device__ float3 fabs(float3 v) +{ + return make_float3(fabs(v.x), fabs(v.y), fabs(v.z)); +} +inline __host__ __device__ float4 fabs(float4 v) +{ + return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w)); +} + +inline __host__ __device__ int2 abs(int2 v) +{ + return make_int2(abs(v.x), abs(v.y)); +} +inline __host__ __device__ int3 abs(int3 v) +{ + return make_int3(abs(v.x), abs(v.y), abs(v.z)); +} +inline __host__ __device__ int4 abs(int4 v) +{ + return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// reflect +// - returns reflection of incident ray I around surface normal N +// - N should be normalized, reflected vector's length is equal to length of I +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float3 reflect(float3 i, float3 n) +{ + return i - 2.0f * n * dot(n, i); +} + +//////////////////////////////////////////////////////////////////////////////// +// cross product +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float3 cross(float3 a, float3 b) +{ + return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); +} + +//////////////////////////////////////////////////////////////////////////////// +// smoothstep +// - returns 0 if x < a +// - returns 1 if x > b +// - otherwise 
returns smooth interpolation between 0 and 1 based on x +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float smoothstep(float a, float b, float x) +{ + float y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(3.0f - (2.0f*y))); +} +inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x) +{ + float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y))); +} +inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x) +{ + float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y))); +} +inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x) +{ + float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y))); +} + +#endif \ No newline at end of file diff --git a/models/roartnet.py b/models/roartnet.py new file mode 100644 index 0000000..d96c329 --- /dev/null +++ b/models/roartnet.py @@ -0,0 +1,97 @@ +from typing import List +from itertools import combinations +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + + +class ResLayer(nn.Module): + def __init__(self, dim_in:int, dim_out:int, bn:bool=False, ln:bool=False, dropout:float=0.): + super().__init__() + self.is_bn = bn + self.is_ln = ln + self.fc1 = nn.Linear(dim_in, dim_out) + if bn: + self.bn1 = nn.BatchNorm1d(dim_out) + else: + self.bn1 = lambda x: x + if ln: + self.ln1 = nn.LayerNorm(dim_out) + else: + self.ln1 = lambda x: x + self.fc2 = nn.Linear(dim_out, dim_out) + if bn: + self.bn2 = nn.BatchNorm1d(dim_out) + else: + self.bn2 = lambda x: x + if ln: + self.ln2 = nn.LayerNorm(dim_out) + else: + self.ln2 = lambda x: x + if dim_in != dim_out: + self.fc0 = nn.Linear(dim_in, dim_out) + else: + self.fc0 = None + self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() + + def forward(self, x): + x_res = x if self.fc0 is None else self.fc0(x) + x = self.fc1(x) + if len(x.shape) > 3 or len(x.shape) < 2: + raise ValueError("x.shape should be (B, N, D) or (N, D)") + elif len(x.shape) == 3 and self.is_bn: + x = x.permute(0, 2, 1) # from (B, N, D) to (B, D, N) + x = self.bn1(x) + x = x.permute(0, 2, 1) # from (B, D, N) to (B, N, D) + elif len(x.shape) == 2 and self.is_bn: + x = self.bn1(x) + elif self.is_ln: + x = self.ln1(x) + else: + x = self.bn1(x) # actually self.bn1 is identity function + x = F.relu(x) + + x = self.fc2(x) + if len(x.shape) > 3 or len(x.shape) < 2: + raise ValueError("x.shape should be (B, N, D) or (N, D)") + elif len(x.shape) == 3 and self.is_bn: + x = x.permute(0, 2, 1) # from (B, N, D) to (B, D, N) + x = self.bn2(x) + x = x.permute(0, 2, 1) # from (B, D, N) to (B, N, D) + elif len(x.shape) == 2 and self.is_bn: + x = self.bn2(x) + elif self.is_ln: + x = self.ln2(x) + else: + x = self.bn2(x) # actually self.bn2 is identity function + x = self.dropout(x + x_res) + return x + + +def create_MLP(input_dim:int, hidden_dims:List[int], output_dim:int, + bn:bool, ln:bool, dropout:float) -> nn.Module: + fcs = hidden_dims + fcs.insert(0, input_dim) + fcs.append(output_dim) + MLP = nn.Sequential( + *[ResLayer(fcs[i], fcs[i+1], bn=bn, ln=ln, dropout=dropout) + for i in range(len(fcs) - 1)] + ) + return MLP + + +def create_shot_encoder(shot_hidden_dims:List[int], shot_feature_dim:int, + shot_bn:bool, shot_ln:bool, shot_dropout:float) -> nn.Module: + return create_MLP(input_dim=352, hidden_dims=shot_hidden_dims, output_dim=shot_feature_dim, + bn=shot_bn, ln=shot_ln, dropout=shot_dropout) + +def create_encoder(num_more:int, shot_feature_dim:int, has_rgb:bool, + overall_hidden_dims:List[int], rot_bin_num:int, joint_num:int, + overall_bn:bool, overall_ln:bool, overall_dropout:float) -> nn.Module: + # input order: (coords, normals, shots(, rgb)) + overall_input_dim = len(list(combinations(np.arange(num_more + 2), 2))) * 4 + (num_more + 2) * shot_feature_dim + (3 * (num_more + 2) if has_rgb else 0) + # output order: (J*tr, J*rot, J*afford, J*conf) + overall_output_dim = (2 + rot_bin_num + 2 + 1) * joint_num + + return create_MLP(input_dim=overall_input_dim, hidden_dims=overall_hidden_dims, output_dim=overall_output_dim, + bn=overall_bn, ln=overall_ln, dropout=overall_dropout) diff --git a/models/voting.py b/models/voting.py new file mode 100644 index 0000000..3fed7cf --- /dev/null +++ b/models/voting.py @@ -0,0 +1,321 @@ +""" +Modified from https://github.com/qq456cvb/CPPF/blob/main/models/voting.py +""" +import os +import cupy as cp + + +helper_math_path = os.path.join(os.path.dirname(__file__), 'helper_math.cuh') + + +ppf_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r''' + #define M_PI 3.14159265358979323846264338327950288 + extern "C" __global__ + void ppf_voting( + const float *points, const float *outputs, const float *probs, const int *point_idxs, float *grid_obj, const float *corner, const float res, + int n_ppfs, int n_rots, int grid_x, int grid_y, int grid_z + ) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n_ppfs) { + float proj_len = outputs[idx * 2]; + float odist = outputs[idx * 2 + 1]; + if (odist < res) return; + int a_idx = point_idxs[idx * 2]; + int b_idx = point_idxs[idx * 2 + 1]; + float3 a = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]); + float3 b = make_float3(points[b_idx * 3], points[b_idx * 3 + 1], points[b_idx * 3 + 2]); + float3 ab = a - b; + 
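// Voting geometry: each point pair (a, b) predicts (proj_len, odist); c below lies on the a-b line at
+ // proj_len from a, and candidate target positions are sampled on a circle of radius odist around c
+ // in the plane perpendicular to ab, each vote weighted by prob and accumulated into grid_obj.
+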
if (length(ab) < 1e-7) return; + ab /= (length(ab) + 1e-7); + float3 c = a - ab * proj_len; + + // float prob = max(probs[a_idx], probs[b_idx]); + float prob = probs[idx]; + float3 co = make_float3(0.f, -ab.z, ab.y); + if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f); + float3 x = co / (length(co) + 1e-7) * odist; + float3 y = cross(x, ab); + int adaptive_n_rots = min(int(odist / res * (2 * M_PI)), n_rots); + // int adaptive_n_rots = n_rots; + for (int i = 0; i < adaptive_n_rots; i++) { + float angle = i * 2 * M_PI / adaptive_n_rots; + float3 offset = cos(angle) * x + sin(angle) * y; + float3 center_grid = (c + offset - make_float3(corner[0], corner[1], corner[2])) / res; + if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 || + center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) { + continue; + } + int3 center_grid_floor = make_int3(center_grid); + int3 center_grid_ceil = center_grid_floor + 1; + float3 residual = fracf(center_grid); + + float3 w0 = 1.f - residual; + float3 w1 = residual; + + float lll = w0.x * w0.y * w0.z; + float llh = w0.x * w0.y * w1.z; + float lhl = w0.x * w1.y * w0.z; + float lhh = w0.x * w1.y * w1.z; + float hll = w1.x * w0.y * w0.z; + float hlh = w1.x * w0.y * w1.z; + float hhl = w1.x * w1.y * w0.z; + float hhh = w1.x * w1.y * w1.z; + + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], lll * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], llh * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], lhl * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], lhh * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], hll * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], hlh * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], hhl * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], hhh * prob); + } + } + } +''', 'ppf_voting') + +ppf_retrieval_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r''' + #define M_PI 3.14159265358979323846264338327950288 + extern "C" __global__ + void ppf_voting_retrieval( + const float *point_pairs, const float *outputs, const float *probs, float *grid_obj, const float *corner, const float res, + int n_ppfs, int n_rots, int grid_x, int grid_y, int grid_z + ) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n_ppfs) { + float proj_len = outputs[idx * 2]; + float odist = outputs[idx * 2 + 1]; + if (odist < res) return; + float3 a = make_float3(point_pairs[idx * 6], point_pairs[idx * 6 + 1], point_pairs[idx * 6 + 2]); + float3 b = make_float3(point_pairs[idx * 6 + 3], point_pairs[idx * 6 + 4], point_pairs[idx * 6 + 5]); + float3 ab = a - b; + if (length(ab) < 1e-7) return; + ab /= (length(ab) + 1e-7); + float3 c = a - ab * proj_len; + + // float prob = max(probs[a_idx], probs[b_idx]); + float prob = probs[idx]; + float3 co = make_float3(0.f, -ab.z, ab.y); + if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f); + float3 x = co / (length(co) + 1e-7) * odist; + float3 y = cross(x, ab); + int 
adaptive_n_rots = min(int(odist / res * (2 * M_PI)), n_rots); + // int adaptive_n_rots = n_rots; + for (int i = 0; i < adaptive_n_rots; i++) { + float angle = i * 2 * M_PI / adaptive_n_rots; + float3 offset = cos(angle) * x + sin(angle) * y; + float3 center_grid = (c + offset - make_float3(corner[0], corner[1], corner[2])) / res; + if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 || + center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) { + continue; + } + int3 center_grid_floor = make_int3(center_grid); + int3 center_grid_ceil = center_grid_floor + 1; + float3 residual = fracf(center_grid); + + float3 w0 = 1.f - residual; + float3 w1 = residual; + + float lll = w0.x * w0.y * w0.z; + float llh = w0.x * w0.y * w1.z; + float lhl = w0.x * w1.y * w0.z; + float lhh = w0.x * w1.y * w1.z; + float hll = w1.x * w0.y * w0.z; + float hlh = w1.x * w0.y * w1.z; + float hhl = w1.x * w1.y * w0.z; + float hhh = w1.x * w1.y * w1.z; + + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], lll * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], llh * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], lhl * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], lhh * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], hll * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], hlh * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], hhl * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], hhh * prob); + } + } + } +''', 'ppf_voting_retrieval') + +ppf_direct_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r''' + #define M_PI 3.14159265358979323846264338327950288 + extern "C" __global__ + void ppf_voting_direct( + const float *points, const float *outputs, const float *probs, const int *point_idxs, float *grid_obj, const float *corner, const float res, + int n_ppfs, int grid_x, int grid_y, int grid_z + ) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n_ppfs) { + int a_idx = point_idxs[idx]; + float3 c = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]); + float prob = probs[idx]; + float3 offset = make_float3(outputs[idx * 3], outputs[idx * 3 + 1], outputs[idx * 3 + 2]); + float3 center_grid = (c - offset - make_float3(corner[0], corner[1], corner[2])) / res; + if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 || + center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) { + return; + } + int3 center_grid_floor = make_int3(center_grid); + int3 center_grid_ceil = center_grid_floor + 1; + float3 residual = fracf(center_grid); + + float3 w0 = 1.f - residual; + float3 w1 = residual; + + float lll = w0.x * w0.y * w0.z; + float llh = w0.x * w0.y * w1.z; + float lhl = w0.x * w1.y * w0.z; + float lhh = w0.x * w1.y * w1.z; + float hll = w1.x * w0.y * w0.z; + float hlh = w1.x * w0.y * w1.z; + float hhl = w1.x * w1.y * w0.z; + float hhh = w1.x * w1.y * w1.z; + + 
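// Scatter the vote into the 8 neighboring grid cells with trilinear weights (lll..hhh);
+ // atomicAdd is required since votes from many threads can land in the same cells.
+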
atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], lll * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], llh * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], lhl * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], lhh * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_floor.z], hll * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_floor.y * grid_z + center_grid_ceil.z], hlh * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_floor.z], hhl * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z + center_grid_ceil.y * grid_z + center_grid_ceil.z], hhh * prob); + } + } +''', 'ppf_voting_direct') + + +rot_voting_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r''' + #define M_PI 3.14159265358979323846264338327950288 + extern "C" __global__ + void rot_voting( + const float *points, const float *preds_rot, float3 *outputs_up, const int *point_idxs, + int n_ppfs, int n_rots + ) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n_ppfs) { + float rot = preds_rot[idx]; + int a_idx = point_idxs[idx * 2]; + int b_idx = point_idxs[idx * 2 + 1]; + float3 a = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]); + float3 b = make_float3(points[b_idx * 3], points[b_idx * 3 + 1], points[b_idx * 3 + 2]); + float3 ab = a - b; + if (length(ab) < 1e-7) return; + ab /= length(ab); + + float3 co = make_float3(0.f, -ab.z, ab.y); + if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f); + float3 x = co / (length(co) + 1e-7); + float3 y = cross(x, ab); + + for (int i = 0; i < n_rots; i++) { + float angle = i * 2 * M_PI / n_rots; + float3 offset = cos(angle) * x + sin(angle) * y; + float3 up = tan(rot) * offset + (tan(rot) > 0 ? ab : -ab); + up = up / (length(up) + 1e-7); + outputs_up[idx * n_rots + i] = up; + } + } + } +''', 'rot_voting') + +rot_voting_retrieval_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r''' + #define M_PI 3.14159265358979323846264338327950288 + extern "C" __global__ + void rot_voting_retrieval( + const float *point_pairs, const float *preds_rot, float3 *outputs_up, + int n_ppfs, int n_rots + ) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n_ppfs) { + float rot = preds_rot[idx]; + float3 a = make_float3(point_pairs[idx * 6], point_pairs[idx * 6 + 1], point_pairs[idx * 6 + 2]); + float3 b = make_float3(point_pairs[idx * 6 + 3], point_pairs[idx * 6 + 4], point_pairs[idx * 6 + 5]); + float3 ab = a - b; + if (length(ab) < 1e-7) return; + ab /= length(ab); + + float3 co = make_float3(0.f, -ab.z, ab.y); + if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f); + float3 x = co / (length(co) + 1e-7); + float3 y = cross(x, ab); + + for (int i = 0; i < n_rots; i++) { + float angle = i * 2 * M_PI / n_rots; + float3 offset = cos(angle) * x + sin(angle) * y; + float3 up = tan(rot) * offset + (tan(rot) > 0 ? 
ab : -ab); + up = up / (length(up) + 1e-7); + outputs_up[idx * n_rots + i] = up; + } + } + } +''', 'rot_voting_retrieval') + + +ppf4d_kernel = cp.RawKernel(f'#include "{helper_math_path}"\n' + r''' + #define M_PI 3.14159265358979323846264338327950288 + extern "C" __global__ + void ppf4d_voting( + const float *points, const float *outputs, const float *rot_outputs, const float *probs, const int *point_idxs, float *grid_obj, const float *corner, const float res, + int n_ppfs, int n_rots, int grid_x, int grid_y, int grid_z, int grid_w + ) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n_ppfs) { + float proj_len = outputs[idx * 2]; + float odist = outputs[idx * 2 + 1]; + if (odist < res) return; + int a_idx = point_idxs[idx * 2]; + int b_idx = point_idxs[idx * 2 + 1]; + float3 a = make_float3(points[a_idx * 3], points[a_idx * 3 + 1], points[a_idx * 3 + 2]); + float3 b = make_float3(points[b_idx * 3], points[b_idx * 3 + 1], points[b_idx * 3 + 2]); + float3 ab = a - b; + if (length(ab) < 1e-7) return; + ab /= (length(ab) + 1e-7); + float3 c = a - ab * proj_len; + + // float prob = max(probs[a_idx], probs[b_idx]); + float prob = probs[idx]; + float3 co = make_float3(0.f, -ab.z, ab.y); + if (length(co) < 1e-7) co = make_float3(-ab.y, ab.x, 0.f); + float3 x = co / (length(co) + 1e-7) * odist; + float3 y = cross(x, ab); + int adaptive_n_rots = min(int(odist / res * (2 * M_PI)), n_rots); + // int adaptive_n_rots = n_rots; + for (int i = 0; i < adaptive_n_rots; i++) { + float angle = i * 2 * M_PI / adaptive_n_rots; + float3 offset = cos(angle) * x + sin(angle) * y; + float3 center_grid = (c + offset - make_float3(corner[0], corner[1], corner[2])) / res; + if (center_grid.x < 0.01 || center_grid.y < 0.01 || center_grid.z < 0.01 || + center_grid.x >= grid_x - 1.01 || center_grid.y >= grid_y - 1.01 || center_grid.z >= grid_z - 1.01) { + continue; + } + int3 center_grid_floor = make_int3(center_grid); + int3 center_grid_ceil = center_grid_floor + 1; + float3 residual = fracf(center_grid); + + float3 w0 = 1.f - residual; + float3 w1 = residual; + + float lll = w0.x * w0.y * w0.z; + float llh = w0.x * w0.y * w1.z; + float lhl = w0.x * w1.y * w0.z; + float lhh = w0.x * w1.y * w1.z; + float hll = w1.x * w0.y * w0.z; + float hlh = w1.x * w0.y * w1.z; + float hhl = w1.x * w1.y * w0.z; + float hhh = w1.x * w1.y * w1.z; + + for (int j = 0; j < grid_w; j++) { + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * lll * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * llh * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * lhl * prob); + atomicAdd(&grid_obj[center_grid_floor.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * lhh * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * hll * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_floor.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * hlh * prob); + 
atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_floor.z * grid_w + j], rot_outputs[idx * grid_w + j] * hhl * prob); + atomicAdd(&grid_obj[center_grid_ceil.x * grid_y * grid_z * grid_w + center_grid_ceil.y * grid_z * grid_w + center_grid_ceil.z * grid_w + j], rot_outputs[idx * grid_w + j] * hhh * prob); + } + } + } + } +''', 'ppf4d_voting') diff --git a/outputs/.gitkeep b/outputs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/real_eval.py b/real_eval.py new file mode 100644 index 0000000..4c2efca --- /dev/null +++ b/real_eval.py @@ -0,0 +1,257 @@ +import os +import time +import tqdm +import numpy as np +import transformations as tf + +from envs.real_camera import CameraL515 +from envs.real_robot import Panda +from utilities.data_utils import transform_pc, transform_dir +from utilities.vis_utils import visualize +from utilities.network_utils import send, read + + +if __name__ == '__main__': + print("please manually set the robot to a specific pose, make sure remote services all running") + + cam2EE_path = "/home/franka/junbo/data/robot/l515_franka.npy" + cam2EE = np.load(cam2EE_path) + camera_loaded = False + robot_loaded = False + vis = True + temp_observation_path = "./temp_data/observation.npz" + temp_service_path = "./temp_data/service.npz" + temp_flag_path = "./temp_data/flag.npy" + remote_repo_path = "TODO" # TODO: set your own remote repo path + remote_observation_path = f"{remote_repo_path}/temp_data/observation.npz" + remote_service_path = f"{remote_repo_path}/temp_data/service.npz" + remote_flag_path = f"{remote_repo_path}/temp_data/flag.npy" + remote_ip = "TODO" # TODO: set your own remote ip + port = TODO # TODO: set your own remote port + username = "TODO" # TODO: set your own remote username + key_filename = "TODO" # TODO: set your own local key file path + task = -1 # close as 1, open as -1 + time_steps = 10 + + try: + print("===> initializing camera") + start_time = time.time() + camera = CameraL515() + camera_loaded = True + end_time = time.time() + print("===> camera initialized", end_time - start_time) + + print("===> initializing robot") + start_time = time.time() + robot = Panda() + robot.gripper_open() + robot.homing() + robot_loaded = True + end_time = time.time() + print("===> robot initialized", end_time - start_time) + + print("===> getting observation") + start_time = time.time() + color, depth = camera.get_data(hole_filling=False) + depth_sensor = camera.pipeline_profile.get_device().first_depth_sensor() + depth_scale = depth_sensor.get_depth_scale() + xyzrgb = camera.getXYZRGB(color, depth, np.identity(4), np.identity(4), camera.getIntrinsics(), inpaint=False, depth_scale=depth_scale) + # xyzrgb = xyzrgb[xyzrgb[:, 2] <= 1.5, :] + xyzrgb = xyzrgb[xyzrgb[:, 2] > 0.05, :] + cam_pc = xyzrgb[:, 0:3] + pc_color = xyzrgb[:, 3:6] + end_time = time.time() + print("===> observation got", end_time - start_time) + + print("===> preprocessing observation") + start_time = time.time() + EE2robot = robot.readPose() + cam2robot = EE2robot @ cam2EE + robot2cam = np.linalg.inv(cam2robot) + base_pc = transform_pc(cam_pc, cam2robot) + space_mask_x = np.logical_and(base_pc[:, 0] > 0, base_pc[:, 0] < 1.1) + space_mask_y = np.logical_and(base_pc[:, 1] > -0.27, base_pc[:, 1] < 0.55) + # space_mask_z = base_pc[:, 2] > 0.02 + # space_mask_z = base_pc[:, 2] > 0.55 # microwave: pad + safe (rotate) + # space_mask_z = base_pc[:, 2] > 0.52 # refrigerator: storagefurniture + # space_mask_z = base_pc[:, 2] > 
0.4 # safe: pad + microwave + # space_mask_z = base_pc[:, 2] > 0.27 # storagefurniture: microwave + # space_mask_z = base_pc[:, 2] > 0.27 # drawer: microwave + space_mask_z = base_pc[:, 2] > 0.4 # washingmachine: pad + microwave + space_mask = np.logical_and(np.logical_and(space_mask_x, space_mask_y), space_mask_z) + base_pc_space = base_pc[space_mask, :] + pc_color_space = pc_color[space_mask, :] + cam_pc_space = transform_pc(base_pc_space, robot2cam) + end_time = time.time() + print("===> observation preprocessed", end_time - start_time) + np.savez("./observation.npz", point_cloud=cam_pc_space, rgb=pc_color_space) + if vis: + visualize(cam_pc_space, pc_color_space, whether_frame=True, whether_bbox=True, window_name="observation") + + print("===> sending request") + start_time = time.time() + np.savez(temp_observation_path, point_cloud=cam_pc_space, rgb=pc_color_space) + time.sleep(0.5) + while not (os.path.isfile(temp_observation_path) and os.access(temp_observation_path, os.R_OK)): + time.sleep(0.1) + send(temp_observation_path, remote_observation_path, + remote_ip=remote_ip, port=port, username=username, key_filename=key_filename) + time.sleep(0.5) + os.remove(temp_observation_path) + end_time = time.time() + print("===> request sent", end_time - start_time) + + print("===> reading response") + start_time = time.time() + while True: + read(temp_flag_path, remote_flag_path, + remote_ip=remote_ip, port=port, username=username, key_filename=key_filename) + time.sleep(0.5) + got_service = np.load(temp_flag_path).item() + if got_service: + os.remove(temp_flag_path) + break + else: + time.sleep(0.5) + read(temp_service_path, remote_service_path, + remote_ip=remote_ip, port=port, username=username, key_filename=key_filename) + time.sleep(0.5) + service = np.load(temp_service_path, allow_pickle=True) + num_grasps = service['num_grasps'] + if num_grasps == 0: + print("no grasps detected") + else: + cam_joint_base = service['joint_base'] + cam_joint_direction = service['joint_direction'] + cam_affordable_position = service['affordable_position'] + joint_type = service['joint_type'] + joint_re = service['joint_re'] + grasp_score = service['grasp_score'] + grasp_width = service['grasp_width'] + grasp_depth = service['grasp_depth'] + grasp_affordance = service['grasp_affordance'] + cam_grasp_translation = service['grasp_translation'] + cam_grasp_rotation = service['grasp_rotation'] + cam_grasp_pose = np.eye(4) + cam_grasp_pose[:3, 3] = cam_grasp_translation + cam_grasp_pose[:3, :3] = cam_grasp_rotation + base_joint_base = transform_pc(cam_joint_base[None, :], cam2robot)[0] + base_joint_direction = transform_dir(cam_joint_direction[None, :], cam2robot)[0] + base_affordable_position = transform_pc(cam_affordable_position[None, :], cam2robot)[0] + base_grasp_pose = cam2robot @ cam_grasp_pose + base_grasp_pose[:3, 3] += (grasp_depth - 0.05) * base_grasp_pose[:3, 0] # TODO: hardcode to avoid collision + if joint_type == 0: + # TODO: only for horizontal grasp to avoid singular robot state + flip = np.arccos(np.dot(base_grasp_pose[:3, 2], np.array([0., 0., 1.]))) / np.pi * 180.0 < 45 + if flip: + print("flipped") + base_grasp_pose[:3, 1] = -base_grasp_pose[:3, 1] + base_grasp_pose[:3, 2] = -base_grasp_pose[:3, 2] + rotate = base_grasp_pose[:3, 0][2] > 0 + if rotate: + print("rotated") + target_x_axis = base_grasp_pose[:3, 0].copy() + target_x_axis[2] = -target_x_axis[2] + rotation_angle = np.arccos(np.dot(base_grasp_pose[:3, 0], target_x_axis)) + rotation_direction = np.array([base_grasp_pose[:3, 0][0], 
base_grasp_pose[:3, 0][1]]) + rotation_direction /= np.linalg.norm(rotation_direction) + rotation_direction = np.array([-rotation_direction[1], rotation_direction[0], 0.]) + rotation_matrix = tf.rotation_matrix(angle=rotation_angle, direction=rotation_direction, point=base_grasp_pose[:3, 3]) + base_grasp_pose = rotation_matrix @ base_grasp_pose + elif joint_type == 1: + horizontal = np.arccos(np.dot(base_grasp_pose[:3, 0], np.array([1., 0., 0.]))) / np.pi * 180.0 < 45 + if horizontal: + print("horizontal") + else: + print("vertical") + else: + raise ValueError + base_pre_grasp_pose = base_grasp_pose.copy() + base_pre_grasp_pose[:3, 3] -= 0.05 * base_pre_grasp_pose[:3, 0] + g2g = np.array([[0., 0., -1.], [0., -1., 0.], [-1., 0., 0.]]) + base_gripper_pose = np.eye(4) + base_gripper_pose[:3, :3] = base_grasp_pose[:3, :3] @ g2g + base_gripper_pose[:3, 3] = base_grasp_pose[:3, 3] + base_pre_gripper_pose = np.eye(4) + base_pre_gripper_pose[:3, :3] = base_pre_grasp_pose[:3, :3] @ g2g + base_pre_gripper_pose[:3, 3] = base_pre_grasp_pose[:3, 3] + np.savez("./joint.npz", joint_base=base_joint_base, joint_direction=base_joint_direction, affordable_position=base_affordable_position) + np.savez("./grasp.npz", grasp_pose=base_grasp_pose, grasp_width=grasp_width) + time.sleep(0.5) + os.remove(temp_service_path) + end_time = time.time() + print("===> response read", end_time - start_time) + if vis: + if num_grasps != 0: + visualize(base_pc_space, pc_color_space, + joint_translations=base_joint_base[None, :], joint_rotations=base_joint_direction[None, :], affordable_positions=base_affordable_position[None, :], + grasp_poses=base_grasp_pose[None, ...], grasp_widths=np.array([grasp_width]), grasp_depths=np.array([0.]), grasp_affordances=np.array([grasp_affordance]), + whether_frame=True, whether_bbox=True, window_name="prediction") + + print("===> resetting flag") + serviced = np.array(False) + np.save(temp_flag_path, serviced) + time.sleep(0.5) + while not (os.path.isfile(temp_flag_path) and os.access(temp_flag_path, os.R_OK)): + time.sleep(0.1) + send(temp_flag_path, remote_flag_path, + remote_ip=remote_ip, port=port, username=username, key_filename=key_filename) + time.sleep(0.5) + os.remove(temp_flag_path) + end_time = time.time() + print("===> flag reset", end_time - start_time) + + print("===> starting manipulation") + import pdb; pdb.set_trace() + start_time = time.time() + if num_grasps == 0: + exit(1) + else: + robot.move_gripper(grasp_width) + + robot.movePose(base_pre_gripper_pose) + + robot.movePose(base_gripper_pose) + + # grasp + is_graspped = robot.gripper_close() + is_graspped = is_graspped and robot.is_grasping() + print(is_graspped) + + # move + real_trajectory = [] + target_trajectory = [] + wrench_trajectory = [] + robot.start_impedance_control() + for time_step in tqdm.trange(time_steps): + current_EE2robot = robot.readPose() + current_wrench= robot.readWrench() + if joint_type == 0: + rotation_angle = -5.0 * task * joint_re / 180.0 * np.pi + delta_pose = tf.rotation_matrix(angle=rotation_angle, direction=base_joint_direction, point=base_joint_base) + elif joint_type == 1: + translation_distance = -5.0 * task / 100.0 + delta_pose = tf.translation_matrix(base_joint_direction * translation_distance) + else: + raise ValueError + target_EE2robot = delta_pose @ current_EE2robot + robot.movePose(target_EE2robot) + time.sleep(0.3) + real_trajectory.append(current_EE2robot) + target_trajectory.append(target_EE2robot) + wrench_trajectory.append(current_wrench) + robot.end_impedance_control() + 
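# The logged arrays capture reached-vs-commanded behaviour: real_trajectory[t] is the EE pose read
+ # just before the t-th command and target_trajectory[t] the commanded pose, so tracking error can
+ # be estimated offline by comparing real_trajectory[1:] with target_trajectory[:-1].
+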
real_trajectory = np.array(real_trajectory) + target_trajectory = np.array(target_trajectory) + wrench_trajectory = np.array(wrench_trajectory) + np.savez("./trajectory.npz", real_trajectory=real_trajectory, target_trajectory=target_trajectory, wrench_trajectory=wrench_trajectory) + end_time = time.time() + print("===> manipulation done", end_time - start_time) + robot.gripper_open() + # robot.homing() + except Exception as e: + print(e) + if camera_loaded: + del camera + if robot_loaded: + del robot diff --git a/real_service.py b/real_service.py new file mode 100644 index 0000000..4ba987a --- /dev/null +++ b/real_service.py @@ -0,0 +1,278 @@ +import os +import configargparse +from omegaconf import OmegaConf +import time +import numpy as np +import torch + +from utilities.env_utils import setup_seed +from utilities.data_utils import transform_pc, transform_dir +from utilities.metrics_utils import invaffordance_metrics, invaffordances2affordance +from utilities.constants import seed, max_grasp_width + + +def config_parse() -> configargparse.Namespace: + parser = configargparse.ArgumentParser() + + # data config + parser.add_argument('--cat', type=str, default='Microwave', help='the category of the object') + # model config + parser.add_argument('--roartnet', action='store_true', help='whether call roartnet') + parser.add_argument('--roartnet_config_path', type=str, default='./configs/eval_config.yaml', help='the path to roartnet config') + # grasp config + parser.add_argument('--graspnet', action='store_true', help='whether call graspnet') + parser.add_argument('--gsnet_weight_path', type=str, default='./weights/checkpoint_detection.tar', help='the path to graspnet weight') + parser.add_argument('--max_grasp_width', type=float, default=max_grasp_width, help='the max width of the gripper') + # task config + parser.add_argument('--selected_part', type=int, default=0, help='the selected part of the object') + # others + parser.add_argument('--seed', type=int, default=seed, help='the random seed') + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + print("please clear temporary data directory") + + args = config_parse() + setup_seed(args.seed) + temp_request_path = './temp_data/observation.npz' + temp_response_path = './temp_data/service.npz' + temp_flag_path = './temp_data/flag.npy' + if args.cat == "Microwave": + joint_types = [0] + joint_res = [-1] + elif args.cat == "Refrigerator": + joint_types = [0] + joint_res = [1] + elif args.cat == "Safe": + joint_types = [0] + joint_res = [1] + elif args.cat == "StorageFurniture": + joint_types = [1, 0] + joint_res = [0, -1] + elif args.cat == "Drawer": + joint_types = [1, 1, 1] + joint_res = [0, 0, 0] + elif args.cat == "WashingMachine": + joint_types = [0] + joint_res = [-1] + else: + raise ValueError(f"Unknown category {args.cat}") + + if args.roartnet: + print("===> loading roartnet") + start_time = time.time() + from models.roartnet import create_shot_encoder, create_encoder + from inference import inference_fn as roartnet_inference_fn + roartnet_cfg = OmegaConf.load(args.roartnet_config_path) + trained_path = roartnet_cfg.trained.path[args.cat] + trained_cfg = OmegaConf.load(f"{trained_path}/.hydra/config.yaml") + roartnet_cfg = OmegaConf.merge(trained_cfg, roartnet_cfg) + joint_num = roartnet_cfg.dataset.joint_num + resolution = roartnet_cfg.dataset.resolution + receptive_field = roartnet_cfg.dataset.receptive_field + has_rgb = roartnet_cfg.dataset.rgb + denoise = roartnet_cfg.dataset.denoise + normalize = 
roartnet_cfg.dataset.normalize + sample_points_num = roartnet_cfg.dataset.sample_points_num + sample_tuples_num = roartnet_cfg.algorithm.sampling.sample_tuples_num + tuple_more_num = roartnet_cfg.algorithm.sampling.tuple_more_num + shot_hidden_dims = roartnet_cfg.algorithm.shot_encoder.hidden_dims + shot_feature_dim = roartnet_cfg.algorithm.shot_encoder.feature_dim + shot_bn = roartnet_cfg.algorithm.shot_encoder.bn + shot_ln = roartnet_cfg.algorithm.shot_encoder.ln + shot_dropout = roartnet_cfg.algorithm.shot_encoder.dropout + shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim, + shot_bn, shot_ln, shot_dropout) + shot_encoder.load_state_dict(torch.load(f'{trained_path}/weights/shot_encoder_latest.pth', map_location=torch.device('cuda'))) + shot_encoder = shot_encoder.cuda() + shot_encoder.eval() + overall_hidden_dims = roartnet_cfg.algorithm.encoder.hidden_dims + rot_bin_num = roartnet_cfg.algorithm.voting.rot_bin_num + overall_bn = roartnet_cfg.algorithm.encoder.bn + overall_ln = roartnet_cfg.algorithm.encoder.ln + overall_dropout = roartnet_cfg.algorithm.encoder.dropout + encoder = create_encoder(tuple_more_num, shot_feature_dim, has_rgb, overall_hidden_dims, rot_bin_num, joint_num, + overall_bn, overall_ln, overall_dropout) + encoder.load_state_dict(torch.load(f'{trained_path}/weights/encoder_latest.pth', map_location=torch.device('cuda'))) + encoder = encoder.cuda() + encoder.eval() + voting_num = roartnet_cfg.algorithm.voting.voting_num + angle_tol = roartnet_cfg.algorithm.voting.angle_tol + translation2pc = roartnet_cfg.algorithm.voting.translation2pc + multi_candidate = roartnet_cfg.algorithm.voting.multi_candidate + candidate_threshold = roartnet_cfg.algorithm.voting.candidate_threshold + rotation_multi_neighbor = roartnet_cfg.algorithm.voting.rotation_multi_neighbor + neighbor_threshold = roartnet_cfg.algorithm.voting.neighbor_threshold + rotation_cluster = roartnet_cfg.algorithm.voting.rotation_cluster + bmm_size = roartnet_cfg.algorithm.voting.bmm_size + end_time = time.time() + print(f"===> loaded roartnet {end_time - start_time}") + + if args.graspnet: + print("===> loading graspnet") + start_time = time.time() + from munch import DefaultMunch + from gsnet import AnyGrasp + grasp_detector_cfg = { + 'checkpoint_path': args.gsnet_weight_path, + 'max_gripper_width': args.max_grasp_width, + 'gripper_height': 0.03, + 'top_down_grasp': False, + 'add_vdistance': True, + 'debug': True + } + grasp_detector_cfg = DefaultMunch.fromDict(grasp_detector_cfg) + grasp_detector = AnyGrasp(grasp_detector_cfg) + grasp_detector.load_net() + end_time = time.time() + print(f"===> loaded graspnet {end_time - start_time}") + + while True: + print("===> listening to request") + start_time = time.time() + serviced = np.array(False) + np.save(temp_flag_path, serviced) + got_request = False + while not got_request: + got_request = os.path.exists(temp_request_path) + time.sleep(5.0) # NOTE: hardcode to be longer than the writing time + if got_request: + while not (os.path.isfile(temp_request_path) and os.access(temp_request_path, os.R_OK)): + time.sleep(0.1) + else: + time.sleep(0.1) + observation = np.load(temp_request_path, allow_pickle=True) + cam_pc = observation['point_cloud'] + pc_rgb = observation['rgb'] + c2c = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) + cam_pc_model = transform_pc(cam_pc, c2c) + time.sleep(0.5) + os.remove(temp_request_path) + end_time = time.time() + print(f"===> got request {end_time - start_time}") + + print("===> inferencing 
model") + if args.roartnet: + start_time = time.time() + pred_joint_bases, pred_joint_directions, pred_affordable_positions = roartnet_inference_fn(cam_pc_model, pc_rgb if has_rgb else None, shot_encoder, encoder, + denoise, normalize, resolution, receptive_field, sample_points_num, sample_tuples_num, tuple_more_num, + voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, bmm_size, joint_num, device=0) + pred_selected_joint_base = pred_joint_bases[args.selected_part] + pred_selected_joint_direction = pred_joint_directions[args.selected_part] + pred_selected_affordable_position = pred_affordable_positions[args.selected_part] + pred_selected_joint_base = transform_pc(pred_selected_joint_base[None, :], np.linalg.inv(c2c))[0] + pred_selected_joint_direction = transform_dir(pred_selected_joint_direction[None, :], np.linalg.inv(c2c))[0] + pred_selected_affordable_position = transform_pc(pred_selected_affordable_position[None, :], np.linalg.inv(c2c))[0] + end_time = time.time() + print(f"===> shot_dropout predicted {end_time - start_time}") + + print("===> detecting grasps") + if args.graspnet: + start_time = time.time() + # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=None, lims=[ + # np.floor(np.min(pcd_grasp[:, 0])) - 0.1, np.ceil(np.max(pcd_grasp[:, 0])) + 0.1, + # np.floor(np.min(pcd_grasp[:, 1])) - 0.1, np.ceil(np.max(pcd_grasp[:, 1])) + 0.1, + # np.floor(np.min(pcd_grasp[:, 2])) - 0.1, np.ceil(np.max(pcd_grasp[:, 2])) + 0.1]) + # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=[-float('inf'), float('inf'), -float('inf'), float('inf'), -float('inf'), float('inf')]) + # gg_grasp = grasp_detector.get_grasp(pcd_grasp.astype(np.float32), colors=pcd_color, lims=None, apply_object_mask=True, dense_grasp=False, collision_detection=True) + try: + gg_grasp = grasp_detector.get_grasp(cam_pc.astype(np.float32), colors=pc_rgb, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='fast') + except: + gg_grasp = grasp_detector.get_grasp(cam_pc.astype(np.float32), colors=pc_rgb, lims=None, voxel_size=0.0075, apply_object_mask=False, dense_grasp=True, collision_detection='slow') + if gg_grasp is None: + gg_grasp = [] + else: + if len(gg_grasp) != 2: + gg_grasp = [] + else: + gg_grasp, pcd_o3d = gg_grasp + gg_grasp = gg_grasp.nms().sort_by_score() + # grippers_o3d = gg_grasp.to_open3d_geometry_list() + # frame_o3d = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + # o3d.visualization.draw_geometries([*grippers_o3d, pcd_o3d, frame_o3d]) + if len(gg_grasp) == 0: + np.savez(temp_response_path, + joint_base=pred_selected_joint_base, + joint_direction=pred_selected_joint_direction, + affordable_position=pred_selected_affordable_position, + joint_type=joint_types[args.selected_part], + joint_re=joint_res[args.selected_part], + num_grasps=0) + while not (os.path.isfile(temp_response_path) and os.access(temp_response_path, os.R_OK)): + time.sleep(0.1) + serviced = np.array(True) + np.save(temp_flag_path, serviced) + while True: + while not (os.path.isfile(temp_flag_path) and os.access(temp_flag_path, os.R_OK)): + time.sleep(0.1) + serviced = np.load(temp_flag_path).item() + if not serviced: + os.remove(temp_flag_path) + os.remove(temp_response_path) + break + else: + time.sleep(0.1) + continue + grasp_scores, grasp_widths, grasp_depths, grasp_translations, grasp_rotations, grasp_invaffordances = 
[], [], [], [], [], [] + for g_idx, g_grasp in enumerate(gg_grasp): + grasp_score = g_grasp.score + grasp_scores.append(grasp_score) + grasp_width = g_grasp.width + grasp_widths.append(grasp_width) + grasp_depth = g_grasp.depth + grasp_depths.append(grasp_depth) + grasp_translation = g_grasp.translation + grasp_translations.append(grasp_translation) + grasp_rotation = g_grasp.rotation_matrix + grasp_rotations.append(grasp_rotation) + grasp_invaffordance = invaffordance_metrics(grasp_translation, grasp_rotation, grasp_score, pred_selected_affordable_position, + pred_selected_joint_base, pred_selected_joint_direction, joint_types[args.selected_part]) + grasp_invaffordances.append(grasp_invaffordance) + grasp_affordances = invaffordances2affordance(grasp_invaffordances) + selected_grasp_idx = np.argmax(grasp_affordances) + selected_grasp_score = grasp_scores[selected_grasp_idx] + selected_grasp_width = grasp_widths[selected_grasp_idx] + selected_grasp_width = max(min(selected_grasp_width * 1.5, args.max_grasp_width), 0.0) + selected_grasp_depth = grasp_depths[selected_grasp_idx] + selected_grasp_translation = grasp_translations[selected_grasp_idx] + selected_grasp_rotation = grasp_rotations[selected_grasp_idx] + selected_grasp_affordance = grasp_affordances[selected_grasp_idx] + end_time = time.time() + print(f"===> anygrasp detected {end_time - start_time} {len(gg_grasp)}") + + print("===> sending response") + start_time = time.time() + np.savez(temp_response_path, + joint_base=pred_selected_joint_base, + joint_direction=pred_selected_joint_direction, + affordable_position=pred_selected_affordable_position, + joint_type=joint_types[args.selected_part], + joint_re=joint_res[args.selected_part], + num_grasps=len(gg_grasp), + grasp_score=selected_grasp_score, + grasp_width=selected_grasp_width, + grasp_depth=selected_grasp_depth, + grasp_translation=selected_grasp_translation, + grasp_rotation=selected_grasp_rotation, + grasp_affordance=selected_grasp_affordance) + while not (os.path.isfile(temp_response_path) and os.access(temp_response_path, os.R_OK)): + time.sleep(0.1) + serviced = np.array(True) + np.save(temp_flag_path, serviced) + while True: + while not (os.path.isfile(temp_flag_path) and os.access(temp_flag_path, os.R_OK)): + time.sleep(0.1) + serviced = np.load(temp_flag_path).item() + if not serviced: + os.remove(temp_flag_path) + os.remove(temp_response_path) + break + else: + time.sleep(0.1) + end_time = time.time() + print(f"===> sent response {end_time - start_time}") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..14e8b37 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +requests +ConfigArgParse +trimesh +tqdm +Pillow +scipy +open3d==0.18.0 +opencv-python +omegaconf +hydra-core +scikit-learn +tensorboard +transformations +paramiko diff --git a/scripts/eval_gt.sh b/scripts/eval_gt.sh new file mode 100644 index 0000000..7d63041 --- /dev/null +++ b/scripts/eval_gt.sh @@ -0,0 +1,11 @@ +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task pull --abbr 7119 --seed 42 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path 
/data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task push --abbr 7119 --seed 43 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task pull --abbr 7263 --seed 44 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task push --abbr 7263 --seed 45 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task pull --abbr 7296 --seed 46 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --graspnet --grasp --selected_part 0 --task push --abbr 7296 --seed 47 diff --git a/scripts/eval_roartnet.sh b/scripts/eval_roartnet.sh new file mode 100644 index 0000000..26fb479 --- /dev/null +++ b/scripts/eval_roartnet.sh @@ -0,0 +1,11 @@ +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task pull --abbr 7119 --seed 42 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7119/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7119/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task push --abbr 7119 --seed 43 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task pull --abbr 7263 --seed 44 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7263/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7263/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task push --abbr 7263 --seed 45 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path 
/data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task pull --abbr 7296 --seed 46 + +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python env_eval.py --num_config_per_object 100 --data_path /data2/junbo/where2act_modified_sapien_dataset/7296/mobility_vhacd.urdf --cat Microwave --gt_path /data2/junbo/where2act_modified_sapien_dataset/7296/joint_abs_pose.json --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --grasp --selected_part 0 --task push --abbr 7296 --seed 47 diff --git a/scripts/real_service.sh b/scripts/real_service.sh new file mode 100644 index 0000000..f684545 --- /dev/null +++ b/scripts/real_service.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; CUDA_VISIBLE_DEVICES=0 python real_service.py --cat Microwave --roartnet --roartnet_config_path configs/eval_config.yaml --graspnet --selected_part 0 --seed 42 diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 0000000..a4a095a --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python test.py diff --git a/scripts/test_config.sh b/scripts/test_config.sh new file mode 100644 index 0000000..7d0304e --- /dev/null +++ b/scripts/test_config.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python test.py --cfg job diff --git a/scripts/test_gt.sh b/scripts/test_gt.sh new file mode 100644 index 0000000..7e734e0 --- /dev/null +++ b/scripts/test_gt.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python test_gt.py diff --git a/scripts/test_gt_config.sh b/scripts/test_gt_config.sh new file mode 100644 index 0000000..6edeae5 --- /dev/null +++ b/scripts/test_gt_config.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python test_gt.py --cfg job diff --git a/scripts/test_real.sh b/scripts/test_real.sh new file mode 100644 index 0000000..21dda9b --- /dev/null +++ b/scripts/test_real.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python test_real.py diff --git a/scripts/test_real_config.sh b/scripts/test_real_config.sh new file mode 100644 index 0000000..8248ea1 --- /dev/null +++ b/scripts/test_real_config.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python test_real.py --cfg job diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100644 index 0000000..5c46993 --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python train.py diff --git a/scripts/train_config.sh b/scripts/train_config.sh new file mode 100644 index 0000000..58a18ce --- /dev/null +++ b/scripts/train_config.sh @@ -0,0 +1 @@ +export OMP_NUM_THREADS=16; python train.py --cfg job diff --git a/src_shot/CMakeLists.txt b/src_shot/CMakeLists.txt new file mode 100644 index 0000000..79ba892 --- /dev/null +++ b/src_shot/CMakeLists.txt @@ -0,0 +1,17 @@ +cmake_minimum_required(VERSION 2.8.12) +project(shot) +set (CMAKE_CXX_STANDARD 11) + +find_package( PythonInterp 3.6 REQUIRED ) +find_package( PythonLibs 3.6 REQUIRED ) +find_package( pybind11 REQUIRED ) +find_package( PCL 1.8 REQUIRED ) + +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + +include_directories( ${PCL_INCLUDE_DIRS} ) +# link_directories( ${PCL_LIBRARY_DIRS} ) +add_definitions(${PCL_DEFINITIONS}) + +pybind11_add_module(shot shot.cpp) +target_link_libraries(shot PUBLIC ${PCL_LIBRARIES}) diff --git a/src_shot/shot.cpp b/src_shot/shot.cpp new file mode 100644 index 0000000..02af856 --- /dev/null +++ b/src_shot/shot.cpp @@ 
-0,0 +1,165 @@
+/*
+Modified from https://github.com/qq456cvb/CPPF2/blob/main/src_shot/shot.cpp
+*/
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pcl/point_types.h>
+#include <pcl/features/normal_3d.h>
+#include <pcl/features/shot.h>
+#include <pcl/search/kdtree.h>
+#include <pcl/console/print.h>
+
+namespace py = pybind11;
+using namespace pybind11::literals;
+
+py::array_t<float> estimate_normal(py::array_t<float> pc, double normal_r)
+{
+    pcl::PointCloud<pcl::PointXYZ>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZ>);
+    cloud->points.resize(pc.shape(0));
+    float *pc_ptr = (float*)pc.request().ptr;
+    for (int i = 0; i < pc.shape(0); ++i)
+    {
+        std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]);
+        // std::cout << cloud->points[i] << std::endl;
+        pc_ptr += 3;
+    }
+    pcl::PointCloud<pcl::Normal>::Ptr normals(new pcl::PointCloud<pcl::Normal>);
+
+    pcl::NormalEstimation<pcl::PointXYZ, pcl::Normal> normalEstimation;
+    normalEstimation.setInputCloud(cloud);
+    normalEstimation.setRadiusSearch(normal_r);
+    // normalEstimation.setKSearch(40);
+    pcl::search::KdTree<pcl::PointXYZ>::Ptr kdtree(new pcl::search::KdTree<pcl::PointXYZ>);
+    normalEstimation.setSearchMethod(kdtree);
+    normalEstimation.compute(*normals);
+
+    auto result = py::array_t<float>(normals->points.size() * 3);
+    auto buf = result.request();
+    float *ptr = (float*)buf.ptr;
+    for (int i = 0; i < normals->points.size(); ++i)
+    {
+        std::copy(&normals->points[i].normal[0], &normals->points[i].normal[3], &ptr[i * 3]);
+    }
+    return result;
+}
+
+
+py::array_t<float> compute(py::array_t<float> pc, double normal_r, double shot_r)
+{
+    // Object for storing the point cloud.
+    pcl::PointCloud<pcl::PointXYZ>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZ>);
+    cloud->points.resize(pc.shape(0));
+    float *pc_ptr = (float*)pc.request().ptr;
+    for (int i = 0; i < pc.shape(0); ++i)
+    {
+        std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]);
+        // std::cout << cloud->points[i] << std::endl;
+        pc_ptr += 3;
+    }
+
+    // Object for storing the normals.
+    pcl::PointCloud<pcl::Normal>::Ptr normals(new pcl::PointCloud<pcl::Normal>);
+    // Object for storing the SHOT descriptors for each point.
+    pcl::PointCloud<pcl::SHOT352>::Ptr descriptors(new pcl::PointCloud<pcl::SHOT352>());
+
+    // Note: you would usually perform downsampling now. It has been omitted here
+    // for simplicity, but be aware that computation can take a long time.
+
+    // Estimate the normals.
+    pcl::NormalEstimation<pcl::PointXYZ, pcl::Normal> normalEstimation;
+    normalEstimation.setInputCloud(cloud);
+    normalEstimation.setRadiusSearch(normal_r);
+    // normalEstimation.setKSearch(40);
+    pcl::search::KdTree<pcl::PointXYZ>::Ptr kdtree(new pcl::search::KdTree<pcl::PointXYZ>);
+    normalEstimation.setSearchMethod(kdtree);
+    normalEstimation.compute(*normals);
+
+    // SHOT estimation object.
+    pcl::SHOTEstimation<pcl::PointXYZ, pcl::Normal, pcl::SHOT352> shot;
+    shot.setInputCloud(cloud);
+    shot.setInputNormals(normals);
+    // The radius that defines which of the keypoint's neighbors are described.
+    // If too large, there may be clutter, and if too small, not enough points may be found.
+    shot.setRadiusSearch(shot_r);
+// shot.setKSearch(40);
+    shot.compute(*descriptors);
+
+    auto result = py::array_t<float>(descriptors->points.size() * 352);
+    auto buf = result.request();
+    float *ptr = (float*)buf.ptr;
+
+    for (int i = 0; i < descriptors->points.size(); ++i)
+    {
+        std::copy(&descriptors->points[i].descriptor[0], &descriptors->points[i].descriptor[352], &ptr[i * 352]);
+    }
+    return result;
+}
+
+py::array_t<float> compute_color(py::array_t<float> pc, py::array_t<float> pc_color, double normal_r, double shot_r)
+{
+    // Object for storing the point cloud.
+ pcl::PointCloud<pcl::PointXYZRGB>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZRGB>); + cloud->points.resize(pc.shape(0)); + float *pc_ptr = (float*)pc.request().ptr; + float *color_ptr = (float*)pc_color.request().ptr; + for (int i = 0; i < pc.shape(0); ++i) + { + cloud->points[i].x = *pc_ptr; + cloud->points[i].y = *(pc_ptr + 1); + cloud->points[i].z = *(pc_ptr + 2); + + uint8_t r = (*color_ptr) * 255.f; + uint8_t g = (*(color_ptr + 1)) * 255.f; + uint8_t b = (*(color_ptr + 2)) * 255.f; + uint32_t rgb = ((std::uint32_t)r << 16 | (std::uint32_t)g << 8 | (std::uint32_t)b); + cloud->points[i].rgb = *reinterpret_cast<float*>(&rgb); + // std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]); + // std::copy(color_ptr, color_ptr + 3, &cloud->points[i].data[3]); + pc_ptr += 3; + color_ptr += 3; + } + + // Object for storing the normals. + pcl::PointCloud<pcl::Normal>::Ptr normals(new pcl::PointCloud<pcl::Normal>); + // Object for storing the SHOT descriptors for each point. + pcl::PointCloud<pcl::SHOT1344>::Ptr descriptors(new pcl::PointCloud<pcl::SHOT1344>()); + + // Note: you would usually perform downsampling now. It has been omitted here + // for simplicity, but be aware that computation can take a long time. + + // Estimate the normals. + pcl::NormalEstimation<pcl::PointXYZRGB, pcl::Normal> normalEstimation; + normalEstimation.setInputCloud(cloud); + normalEstimation.setRadiusSearch(normal_r); + // normalEstimation.setKSearch(40); + pcl::search::KdTree<pcl::PointXYZRGB>::Ptr kdtree(new pcl::search::KdTree<pcl::PointXYZRGB>); + normalEstimation.setSearchMethod(kdtree); + normalEstimation.compute(*normals); + + // SHOT estimation object. + pcl::SHOTColorEstimation<pcl::PointXYZRGB, pcl::Normal, pcl::SHOT1344> shot; + shot.setInputCloud(cloud); + shot.setInputNormals(normals); + // The radius that defines which of the keypoint's neighbors are described. + // If too large, there may be clutter, and if too small, not enough points may be found. + shot.setRadiusSearch(shot_r); + shot.compute(*descriptors); + + auto result = py::array_t<float>(descriptors->points.size() * 1344); + auto buf = result.request(); + float *ptr = (float*)buf.ptr; + + for (int i = 0; i < descriptors->points.size(); ++i) + { + std::copy(&descriptors->points[i].descriptor[0], &descriptors->points[i].descriptor[1344], &ptr[i * 1344]); + } + return result; +} + + +PYBIND11_MODULE(shot, m) { + pcl::console::setVerbosityLevel(pcl::console::L_ALWAYS); + m.def("compute", &compute, py::arg("pc"), py::arg("normal_r")=0.1, py::arg("shot_r")=0.17); + m.def("compute_color", &compute_color, py::arg("pc"), py::arg("pc_color"), py::arg("normal_r")=0.1, py::arg("shot_r")=0.17); + m.def("estimate_normal", &estimate_normal, py::arg("pc"), py::arg("normal_r")=0.1); +} diff --git a/temp_data/.gitkeep b/temp_data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test.py b/test.py new file mode 100644 index 0000000..517daa2 --- /dev/null +++ b/test.py @@ -0,0 +1,387 @@ +from typing import Dict, Union +import hydra +from omegaconf import DictConfig, OmegaConf +import logging +import os +import time +import tqdm +from itertools import combinations +import numpy as np +import torch +import torch.nn as nn +from sklearn.cluster import KMeans + +from datasets.rconfmask_afford_point_tuple_dataset import ArticulationDataset +from models.roartnet import create_shot_encoder, create_encoder +from inference import voting_translation, voting_rotation +from utilities.metrics_utils import calc_translation_error, calc_translation_error_batch, calc_direction_error, calc_direction_error_batch, log_metrics +from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting, visualize_confidence_voting +from utilities.env_utils
import setup_seed +from utilities.constants import seed, light_blue_color, red_color, dark_red_color, dark_green_color, yellow_color + + +def test_fn(test_dataloader:torch.utils.data.DataLoader, has_rgb:bool, shot_encoder:nn.Module, encoder:nn.Module, + resolution:float, voting_num:int, rot_bin_num:int, angle_tol:float, + translation2pc:bool, multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool, + rotation_multi_neighbor:bool, neighbor_threshold:float, + bmm_size:int, test_num:int, device:int, vis:bool=False) -> Dict[str, Union[np.ndarray, int]]: + if rotation_cluster: + kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto') + else: + kmeans = None + rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi)) + + tested_num = 0 + with torch.no_grad(): + translation_distance_errors = [] + translation_along_errors = [] + translation_perp_errors = [] + translation_plane_errors = [] + translation_line_errors = [] + translation_outliers = [] + rotation_errors = [] + rotation_outliers = [] + affordance_errors = [] + affordance_outliers = [] + for batch_data in tqdm.tqdm(test_dataloader): + if tested_num >= test_num: + break + if has_rgb: + pcs, pc_normals, pc_shots, pc_colors, joint_translations, joint_rotations, affordable_positions, _, _, _, _, point_idxs_all = batch_data + pcs, pc_normals, pc_shots, pc_colors, point_idxs_all = \ + pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), pc_colors.cuda(device), point_idxs_all.cuda(device) + else: + pcs, pc_normals, pc_shots, joint_translations, joint_rotations, affordable_positions, _, _, _, _, point_idxs_all = batch_data + pcs, pc_normals, pc_shots, point_idxs_all = \ + pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), point_idxs_all.cuda(device) + # (B, N, 3), (B, N, 3), (B, N, 352)(, (B, N, 3)), (B, J, 3), (B, J, 3), (B, J, 3), (B, N_t, 2 + N_m) + B = pcs.shape[0] + N = pcs.shape[1] + J = joint_translations.shape[1] + N_t = point_idxs_all.shape[1] + tested_num += B + + # shot encoder for every point + shot_feat = shot_encoder(pc_shots) # (B, N, N_s) + + # encoder for sampled point tuples + # shot_inputs = torch.cat([shot_feat[point_idxs_all[:, i]] for i in range(0, point_idxs_all.shape[-1])], -1) # (sample_points, feature_dim * (2 + num_more)) + # normal_inputs = torch.cat([torch.max(torch.sum(normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True), + # torch.sum(-normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True)) + # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, (2+num_more \choose 2)) + # coord_inputs = torch.cat([pc[point_idxs_all[:, i]] - pc[point_idxs_all[:, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, 3 * (2+num_more \choose 2)) + # shot_inputs = [] + # normal_inputs = [] + # coord_inputs = [] + # for b in range(pcs.shape[0]): + # shot_inputs.append(torch.cat([shot_feat[b][point_idxs_all[b, :, i]] for i in range(0, point_idxs_all.shape[-1])], dim=-1)) # (sample_points, feature_dim * (2 + num_more)) + # normal_inputs.append(torch.cat([torch.max(torch.sum(normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True), + # torch.sum(-normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True)) + # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, (2+num_more \choose 2)) + # 
coord_inputs.append(torch.cat([pcs[b][point_idxs_all[b, :, i]] - pcs[b][point_idxs_all[b, :, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, 3 * (2+num_more \choose 2)) + # shot_inputs = torch.stack(shot_inputs, dim=0) # (B, sample_points, feature_dim * (2 + num_more)) + # normal_inputs = torch.stack(normal_inputs, dim=0) # (B, sample_points, (2+num_more \choose 2)) + # coord_inputs = torch.stack(coord_inputs, dim=0) # (B, sample_points, 3 * (2+num_more \choose 2)) + shot_inputs = torch.cat([ + torch.gather(shot_feat, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, shot_feat.shape[-1]))) + for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, N_s * (2 + N_m)) + normal_inputs = torch.cat([torch.max( + torch.sum(torch.gather(pc_normals, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pc_normals.shape[-1]))), + dim=-1, keepdim=True), + torch.sum(-torch.gather(pc_normals, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pc_normals.shape[-1]))), + dim=-1, keepdim=True)) + for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, (2+N_m \choose 2)) + coord_inputs = torch.cat([ + torch.gather(pcs, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pcs.shape[-1]))) - + torch.gather(pcs, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pcs.shape[-1]))) + for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, 3 * (2+N_m \choose 2)) + if has_rgb: + rgb_inputs = torch.cat([ + torch.gather(pc_colors, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_colors.shape[-1]))) + for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, 3 * (2 + N_m)) + inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1) + else: + inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1) + preds = encoder(inputs) # (B, N_t, (2 + N_r + 2 + 1) * J) + + # voting + batch_pred_translations, batch_pred_rotations, batch_pred_affordances = [], [], [] + pcs_numpy = pcs.cpu().numpy().astype(np.float32) # (B, N, 3) + pc_normals_numpy = pc_normals.cpu().numpy().astype(np.float32) # (B, N, 3) + joint_translations_numpy = joint_translations.numpy().astype(np.float32) # (B, J, 3) + joint_rotations_numpy = joint_rotations.numpy().astype(np.float32) # (B, J, 3) + affordable_positions_numpy = affordable_positions.numpy().astype(np.float32) # (B, J, 3) + point_idxs_numpy = point_idxs_all[:, :, :2].cpu().numpy().astype(np.int32) # (B, N_t, 2) + preds_numpy = preds.cpu().numpy().astype(np.float32) # (B, N_t, (2 + N_r + 2 + 1) * J) + for b in range(B): + pc = pcs_numpy[b] # (N, 3) + pc_normal = pc_normals_numpy[b] # (N, 3) + joint_translation = joint_translations_numpy[b] # (J, 3) + joint_rotation = joint_rotations_numpy[b] # (J, 3) + affordable_position = affordable_positions_numpy[b] # (J, 3) + point_idx = point_idxs_numpy[b] # (N_t, 2) + pred = preds_numpy[b] # (N_t, (2 + N_r + 2 + 1) * J) + pred_tensor = torch.from_numpy(pred) + + pred_translations, pred_rotations, pred_affordances = [], [], [] + for j in range(J): + # conf selection + pred_conf = torch.sigmoid(pred_tensor[:, -1*J+j]) # (N_t,) + not_selected_indices = pred_conf < 0.5 + pred_conf[not_selected_indices] = 0 + # pred_conf[pred_conf > 0] = 1 + # pred_conf[:] = 1 + pred_conf = 
pred_conf.numpy() + if vis: + visualize_confidence_voting(pred_conf, pc, point_idx, + whether_frame=True, whether_bbox=True, window_name='conf_voting') + import pdb; pdb.set_trace() + + # translation voting + pred_tr = pred[:, 2*j:2*(j+1)] # (N_t, 2) + pred_translation, grid_obj, corners = voting_translation(pc, pred_tr, point_idx, pred_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_translations.append(pred_translation) + + # rotation voting + pred_rot = pred_tensor[:, (2*J+rot_bin_num*j):(2*J+rot_bin_num*(j+1))] # (N_t, rot_bin_num) + pred_rot = torch.softmax(pred_rot, dim=-1) + pred_rot = torch.multinomial(pred_rot, 1).float()[:, 0] # (N_t,) + pred_rot = pred_rot / (rot_bin_num - 1) * np.pi + pred_rot = pred_rot.numpy() + pred_direction, sphere_pts, counts = voting_rotation(pc, pred_rot, point_idx, pred_conf, + rot_candidate_num, angle_tol, voting_num, bmm_size, device, + multi_candidate, candidate_threshold, rotation_cluster, kmeans, + rotation_multi_neighbor, neighbor_threshold) + pred_rotations.append(pred_direction) + + # affordance voting + pred_afford = pred[:, (2*J+rot_bin_num*J+2*j):(2*J+rot_bin_num*J+2*(j+1))] # (N_t, 2) + pred_affordance, agrid_obj, acorners = voting_translation(pc, pred_afford, point_idx, pred_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_affordances.append(pred_affordance) + + translation_errors = calc_translation_error(pred_translation, joint_translation[j], pred_direction, joint_rotation[j]) + if sum(translation_errors) > 20: + translation_outliers.append(translation_errors) + if vis and sum(translation_errors) > 20: + print(f"{translation_errors = }") + indices = np.indices(grid_obj.shape) + indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape)) + votes_list = grid_obj.reshape(-1) + grid_pc = corners[0] + indices_list * resolution + visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color, + gt_translation=joint_translation[j], gt_color=dark_green_color, + pred_translation=pred_translation, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='tr_voting') + import pdb; pdb.set_trace() + direction_error = calc_direction_error(pred_direction, joint_rotation[j]) + if direction_error > 5: + rotation_outliers.append(direction_error) + if vis and direction_error > 5: + print(f"{direction_error = }") + visualize_rotation_voting(sphere_pts, counts, pc, pc_color=light_blue_color, + gt_rotation=joint_rotation[j], gt_color=dark_green_color, + pred_rotation=pred_direction, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='rot_voting') + import pdb; pdb.set_trace() + affordance_error, _, _, _, _ = calc_translation_error(pred_affordance, affordable_position[j], None, None) + if affordance_error > 5: + affordance_outliers.append(affordance_error) + if vis and affordance_error > 5: + print(f"{affordance_error = }") + indices = np.indices(agrid_obj.shape) + indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(agrid_obj.shape)) + votes_list = agrid_obj.reshape(-1) + grid_pc = acorners[0] + indices_list * resolution + visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color, + gt_translation=affordable_position[j], gt_color=dark_green_color, + pred_translation=pred_affordance, pred_color=yellow_color, + show_threshold=candidate_threshold, 
whether_frame=True, whether_bbox=True, window_name='afford_voting') + import pdb; pdb.set_trace() + if vis: + visualize(pc, pc_color=light_blue_color, pc_normal=pc_normal, + joint_translations=np.array(pred_translations), joint_rotations=np.array(pred_rotations), affordable_positions=np.array(pred_affordances), + joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color, + whether_frame=True, whether_bbox=True, window_name='pred') + import pdb; pdb.set_trace() + batch_pred_translations.append(pred_translations) + batch_pred_rotations.append(pred_rotations) + batch_pred_affordances.append(pred_affordances) + batch_pred_translations = np.array(batch_pred_translations).astype(np.float32) # (B, J, 3) + batch_pred_rotations = np.array(batch_pred_rotations).astype(np.float32) # (B, J, 3) + batch_pred_affordances = np.array(batch_pred_affordances).astype(np.float32) # (B, J, 3) + batch_gt_translations = joint_translations.numpy().astype(np.float32) # (B, J, 3) + batch_gt_rotations = joint_rotations.numpy().astype(np.float32) # (B, J, 3) + batch_gt_affordances = affordable_positions.numpy().astype(np.float32) # (B, J, 3) + batch_translation_errors = calc_translation_error_batch(batch_pred_translations, batch_gt_translations, batch_pred_rotations, batch_gt_rotations) # (B, J) + batch_rotation_errors = calc_direction_error_batch(batch_pred_rotations, batch_gt_rotations) # (B, J) + batch_affordance_errors, _, _, _, _ = calc_translation_error_batch(batch_pred_affordances, batch_gt_affordances, None, None) # (B, J) + translation_distance_errors.append(batch_translation_errors[0]) + translation_along_errors.append(batch_translation_errors[1]) + translation_perp_errors.append(batch_translation_errors[2]) + translation_plane_errors.append(batch_translation_errors[3]) + translation_line_errors.append(batch_translation_errors[4]) + rotation_errors.append(batch_rotation_errors) + affordance_errors.append(batch_affordance_errors) + translation_distance_errors = np.concatenate(translation_distance_errors, axis=0) # (tested_num, J) + translation_along_errors = np.concatenate(translation_along_errors, axis=0) # (tested_num, J) + translation_perp_errors = np.concatenate(translation_perp_errors, axis=0) # (tested_num, J) + translation_plane_errors = np.concatenate(translation_plane_errors, axis=0) # (tested_num, J) + translation_line_errors = np.concatenate(translation_line_errors, axis=0) # (tested_num, J) + rotation_errors = np.concatenate(rotation_errors, axis=0) # (tested_num, J) + affordance_errors = np.concatenate(affordance_errors, axis=0) # (tested_num, J) + + return { + 'translation_distance_errors': translation_distance_errors, + 'translation_along_errors': translation_along_errors, + 'translation_perp_errors': translation_perp_errors, + 'translation_plane_errors': translation_plane_errors, + 'translation_line_errors': translation_line_errors, + 'translation_outliers_num': len(translation_outliers), + 'rotation_errors': rotation_errors, + 'rotation_outliers_num': len(rotation_outliers), + 'affordance_errors': affordance_errors, + 'affordance_outliers_num': len(affordance_outliers) + } + + +@hydra.main(config_path='./configs', config_name='test_config', version_base='1.2') +def test(cfg:DictConfig) -> None: + logger = logging.getLogger('test') + hydra_cfg = hydra.core.hydra_config.HydraConfig.get() + output_dir = hydra_cfg['runtime']['output_dir'] + setup_seed(seed=cfg.testing.seed) + trained_path = cfg.trained.path + trained_cfg = 
OmegaConf.load(f"{trained_path}/.hydra/config.yaml") + # merge trained_cfg into cfg, cfg has higher priority + cfg = OmegaConf.merge(trained_cfg, cfg) + print(OmegaConf.to_yaml(cfg)) + + # prepare dataset + logger.info("Preparing dataset...") + device = cfg.testing.device + training_path = cfg.dataset.train_path + testing_path = cfg.dataset.test_path + training_categories = cfg.dataset.train_categories + testing_categories = cfg.dataset.test_categories + joint_num = cfg.dataset.joint_num + resolution = cfg.dataset.resolution + receptive_field = cfg.dataset.receptive_field + noise = cfg.dataset.noise + distortion_rate = cfg.dataset.distortion_rate + distortion_level = cfg.dataset.distortion_level + outlier_rate = cfg.dataset.outlier_rate + outlier_level = cfg.dataset.outlier_level + rgb = cfg.dataset.rgb + denoise = cfg.dataset.denoise + normalize = cfg.dataset.normalize + sample_points_num = cfg.dataset.sample_points_num + sample_tuples_num = cfg.algorithm.sampling.sample_tuples_num + tuple_more_num = cfg.algorithm.sampling.tuple_more_num + training_dataset = ArticulationDataset(training_path, training_categories, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + noise, distortion_rate, distortion_level, outlier_rate, outlier_level, + rgb, denoise, normalize, debug=False, vis=False, is_train=False) + + batch_size = cfg.testing.batch_size + num_workers = cfg.testing.num_workers + training_dataloader = torch.utils.data.DataLoader(training_dataset, pin_memory=True, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + testing_dataset = ArticulationDataset(testing_path, testing_categories, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + noise, distortion_rate, distortion_level, outlier_rate, outlier_level, + rgb, denoise, normalize, debug=False, vis=False, is_train=False) + testing_dataloader = torch.utils.data.DataLoader(testing_dataset, pin_memory=True, batch_size=batch_size, shuffle=False, num_workers=num_workers) + logger.info("Prepared dataset.") + + # prepare model + logger.info("Preparing model...") + shot_hidden_dims = cfg.algorithm.shot_encoder.hidden_dims + shot_feature_dim = cfg.algorithm.shot_encoder.feature_dim + shot_bn = cfg.algorithm.shot_encoder.bn + shot_ln = cfg.algorithm.shot_encoder.ln + shot_dropout = cfg.algorithm.shot_encoder.dropout + shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim, + shot_bn, shot_ln, shot_dropout) + shot_encoder.load_state_dict(torch.load(f'{os.path.join(trained_path, "weights")}/shot_encoder_latest.pth', map_location=torch.device(device))) + shot_encoder = shot_encoder.cuda(device) + overall_hidden_dims = cfg.algorithm.encoder.hidden_dims + rot_bin_num = cfg.algorithm.voting.rot_bin_num + overall_bn = cfg.algorithm.encoder.bn + overall_ln = cfg.algorithm.encoder.ln + overall_dropout = cfg.algorithm.encoder.dropout + encoder = create_encoder(tuple_more_num, shot_feature_dim, rgb, overall_hidden_dims, rot_bin_num, joint_num, + overall_bn, overall_ln, overall_dropout) + encoder.load_state_dict(torch.load(f'{os.path.join(trained_path, "weights")}/encoder_latest.pth', map_location=torch.device(device))) + encoder = encoder.cuda(device) + logger.info("Prepared model.") + + # testing + voting_num = cfg.algorithm.voting.voting_num + angle_tol = cfg.algorithm.voting.angle_tol + translation2pc = cfg.algorithm.voting.translation2pc + multi_candidate = cfg.algorithm.voting.multi_candidate + candidate_threshold = 
cfg.algorithm.voting.candidate_threshold + rotation_multi_neighbor = cfg.algorithm.voting.rotation_multi_neighbor + neighbor_threshold = cfg.algorithm.voting.neighbor_threshold + rotation_cluster = cfg.algorithm.voting.rotation_cluster + bmm_size = cfg.algorithm.voting.bmm_size + logger.info("Testing...") + testing_testing_start_time = time.time() + shot_encoder.eval() + encoder.eval() + + testing_testing_results = test_fn(testing_dataloader, rgb, shot_encoder, encoder, + resolution, voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, + bmm_size, len(testing_dataset), device, vis=cfg.vis) + log_metrics(testing_testing_results, logger, output_dir, tb_writer=None) + + testing_testing_end_time = time.time() + logger.info("Tested.") + logger.info("Testing time: " + str(testing_testing_end_time - testing_testing_start_time)) + + if cfg.testing.training: + logger.info("Testing training...") + testing_training_start_time = time.time() + shot_encoder.eval() + encoder.eval() + + testing_training_results = test_fn(training_dataloader, rgb, shot_encoder, encoder, + resolution, voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, + bmm_size, len(training_dataset), device, vis=cfg.vis) + log_metrics(testing_training_results, logger, output_dir, tb_writer=None) + + testing_training_end_time = time.time() + logger.info("Tested training.") + logger.info("Testing training time: " + str(testing_training_end_time - testing_training_start_time)) + else: + pass + + +if __name__ == '__main__': + test() diff --git a/test_gt.py b/test_gt.py new file mode 100644 index 0000000..aef0c53 --- /dev/null +++ b/test_gt.py @@ -0,0 +1,209 @@ +import hydra +from omegaconf import DictConfig +import logging +import tqdm +import time +import numpy as np +import torch +from sklearn.cluster import KMeans + +from datasets.rconfmask_afford_point_tuple_dataset import ArticulationDataset +from inference import voting_translation, voting_rotation +from utilities.metrics_utils import calc_translation_error, calc_translation_error_batch, calc_direction_error, calc_direction_error_batch, log_metrics +from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting +from utilities.env_utils import setup_seed +from utilities.constants import seed, light_blue_color, red_color, dark_red_color, dark_green_color, yellow_color + + +@hydra.main(config_path='./configs', config_name='test_gt_config', version_base='1.2') +def test_gt(cfg:DictConfig) -> None: + logger = logging.getLogger('test_gt') + hydra_cfg = hydra.core.hydra_config.HydraConfig.get() + output_dir = hydra_cfg['runtime']['output_dir'] + setup_seed(seed=cfg.general.seed) + + # prepare dataset + logger.info("Preparing dataset...") + device = cfg.general.device + path = cfg.dataset.path + categories = cfg.dataset.categories + joint_num = cfg.dataset.joint_num + resolution = cfg.dataset.resolution + receptive_field = cfg.dataset.receptive_field + denoise = cfg.dataset.denoise + normalize = cfg.dataset.normalize + sample_points_num = cfg.dataset.sample_points_num + sample_tuples_num = cfg.algorithm.sample_tuples_num + tuple_more_num = cfg.algorithm.tuple_more_num + dataset = ArticulationDataset(path, categories, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + rgb=False, denoise=denoise, 
normalize=normalize, + debug=False, vis=False, is_train=False) + batch_size = cfg.general.batch_size + num_workers = cfg.general.num_workers + dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers) + logger.info("Prepared dataset.") + + # test + logger.info("GT Testing...") + translation2pc = cfg.algorithm.translation2pc # solve translation voting too much far problem, but maybe sacrifice precision + rotation_cluster = cfg.algorithm.rotation_cluster # solve rotation voting opposite problem, but still cannot solve it completely since maybe only 1 or 2 candidates + if rotation_cluster: + kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto') + else: + kmeans = None + debug = False + multi_candidate = cfg.algorithm.multi_candidate # solve translation and rotation voting discrete problem + candidate_threshold = cfg.algorithm.candidate_threshold + rotation_multi_neighbor = cfg.algorithm.rotation_multi_neighbor # solve rotation voting discrete resolution problem + neighbor_threshold = cfg.algorithm.neighbor_threshold + angle_tol = cfg.algorithm.angle_tol # solve rotation voting discrete resolution problem, but introduce voting opposite problem instead + rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi)) + voting_num = cfg.algorithm.voting_num + bmm_size = cfg.algorithm.bmm_size + with torch.no_grad(): + translation_distance_errors = [] + translation_along_errors = [] + translation_perp_errors = [] + translation_plane_errors = [] + translation_line_errors = [] + translation_outliers = [] + rotation_errors = [] + rotation_outliers = [] + affordance_errors = [] + affordance_outliers = [] + test_gt_start_time = time.time() + for pcs, pc_normals, pc_shots, joint_translations, joint_rotations, affordable_positions, targets_tr, targets_rot, targets_afford, targets_conf, point_idxs_all in tqdm.tqdm(dataloader): + # (B, N, 3), (B, N, 3), (B, N, 352), (B, J, 3), (B, J, 3), (B, J, 3), (B, J, N_t, 2), (B, J, N_t), (B, J, N_t, 2), (B, J, N_t), (B, N_t, 2 + N_m) + B = pcs.shape[0] + N_t = targets_tr.shape[2] + batch_pred_translations, batch_pred_rotations, batch_pred_affordances = [], [], [] + for b in range(B): + pc = pcs[b].numpy().astype(np.float32) # (N, 3) + # pc_normal = pc_normals[b].numpy().astype(np.float32) # (N, 3) + # pc_shot = pc_shots[b].numpy().astype(np.float32) # (N, 352) + joint_translation = joint_translations[b].numpy().astype(np.float32) # (J, 3) + joint_rotation = joint_rotations[b].numpy().astype(np.float32) # (J, 3) + affordable_position = affordable_positions[b].numpy().astype(np.float32) # (J, 3) + target_tr = targets_tr[b].numpy().astype(np.float32) # (J, N_t, 2) + target_rot = targets_rot[b].numpy().astype(np.float32) # (J, N_t) + target_afford = targets_afford[b].numpy().astype(np.float32) # (J, N_t, 2) + target_conf = targets_conf[b].numpy().astype(np.float32) # (J, N_t) + point_idx_all = point_idxs_all[b].numpy().astype(np.int32) # (N_t, 2 + N_m) + + # inference + pred_translations, pred_rotations, pred_affordances = [], [], [] + for j in range(joint_num): + this_target_tr = target_tr[j] # (N_t, 2) + this_target_rot = target_rot[j] # (N_t,) + this_target_afford = target_afford[j] # (N_t, 2) + this_target_conf = target_conf[j] # (N_t,) + + pred_translation, grid_obj, corners = voting_translation(pc, this_target_tr, point_idx_all[:, :2], this_target_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_translations.append(pred_translation) 
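(Aside on the `shot` extension added in `src_shot` above: once built, it is imported straight from Python and returns flat `float32` buffers that the dataset code reshapes per point. A minimal usage sketch, assuming the compiled `.so` sits in `src_shot/build` and that per-point rows follow the copy loops in `shot.cpp`; the radii shown are simply the bound defaults.)

```python
import sys
sys.path.append('src_shot/build')  # assumed location of the compiled module
import numpy as np
import shot

pc = np.random.rand(1024, 3).astype(np.float32)             # (N, 3) point cloud, float32
normals = shot.estimate_normal(pc, 0.1).reshape(-1, 3)      # (N, 3) estimated normals
descriptors = shot.compute(pc, 0.1, 0.17).reshape(-1, 352)  # (N, 352) SHOT descriptors
```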
+ + pred_direction, sphere_pts, counts = voting_rotation(pc, this_target_rot, point_idx_all[:, :2], this_target_conf, + rot_candidate_num, angle_tol, voting_num, bmm_size, device, + multi_candidate, candidate_threshold, rotation_cluster, kmeans, + rotation_multi_neighbor, neighbor_threshold) + pred_rotations.append(pred_direction) + + pred_affordance, agrid_obj, acorners = voting_translation(pc, this_target_afford, point_idx_all[:, :2], this_target_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_affordances.append(pred_affordance) + + translation_errors = calc_translation_error(pred_translation, joint_translation[j], pred_direction, joint_rotation[j]) + if sum(translation_errors) > 20: + translation_outliers.append(translation_errors) + if debug and sum(translation_errors) > 20: + print(f"{translation_errors = }") + indices = np.indices(grid_obj.shape) + indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape)) + votes_list = grid_obj.reshape(-1) + grid_pc = corners[0] + indices_list * resolution + visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color, + gt_translation=joint_translation[j], gt_color=dark_green_color, + pred_translation=pred_translation, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='tr_voting') + import pdb; pdb.set_trace() + direction_error = calc_direction_error(pred_direction, joint_rotation[j]) + if direction_error > 5: + rotation_outliers.append(direction_error) + if debug and direction_error > 5: + print(f"{direction_error = }") + visualize_rotation_voting(sphere_pts, counts, pc, pc_color=light_blue_color, + gt_rotation=joint_rotation[j], gt_color=dark_green_color, + pred_rotation=pred_direction, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='rot_voting') + import pdb; pdb.set_trace() + affordance_error, _, _, _, _ = calc_translation_error(pred_affordance, affordable_position[j], None, None) + if affordance_error > 5: + affordance_outliers.append(affordance_error) + if debug and affordance_error > 5: + print(f"{affordance_error = }") + indices = np.indices(agrid_obj.shape) + indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(agrid_obj.shape)) + votes_list = agrid_obj.reshape(-1) + grid_pc = acorners[0] + indices_list * resolution + visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color, + gt_translation=affordable_position[j], gt_color=dark_green_color, + pred_translation=pred_affordance, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='afford_voting') + import pdb; pdb.set_trace() + print(np.sum(this_target_conf), np.sum(this_target_conf) / N_t, sum(translation_errors), direction_error, affordance_error) + # if debug: + # visualize(pc, pc_color=light_blue_color, pc_normal=pc_normal, + # joint_translations=np.array(pred_translations), joint_rotations=np.array(pred_rotations), affordable_positions=np.array(pred_affordances), + # joint_axis_colors=red_color, joint_point_colors=dark_red_color, + # whether_frame=True, whether_bbox=True, window_name='pred') + # import pdb; pdb.set_trace() + batch_pred_translations.append(pred_translations) + batch_pred_rotations.append(pred_rotations) + batch_pred_affordances.append(pred_affordances) + batch_pred_translations = np.array(batch_pred_translations).astype(np.float32) # (B, J, 3) + 
batch_pred_rotations = np.array(batch_pred_rotations).astype(np.float32) # (B, J, 3) + batch_pred_affordances = np.array(batch_pred_affordances).astype(np.float32) # (B, J, 3) + batch_gt_translations = joint_translations.numpy().astype(np.float32) # (B, J, 3) + batch_gt_rotations = joint_rotations.numpy().astype(np.float32) # (B, J, 3) + batch_gt_affordances = affordable_positions.numpy().astype(np.float32) # (B, J, 3) + batch_translation_errors = calc_translation_error_batch(batch_pred_translations, batch_gt_translations, batch_pred_rotations, batch_gt_rotations) # (B, J) + batch_rotation_errors = calc_direction_error_batch(batch_pred_rotations, batch_gt_rotations) # (B, J) + batch_affordance_errors, _, _, _, _ = calc_translation_error_batch(batch_pred_affordances, batch_gt_affordances, None, None) # (B, J) + translation_distance_errors.append(batch_translation_errors[0]) + translation_along_errors.append(batch_translation_errors[1]) + translation_perp_errors.append(batch_translation_errors[2]) + translation_plane_errors.append(batch_translation_errors[3]) + translation_line_errors.append(batch_translation_errors[4]) + rotation_errors.append(batch_rotation_errors) + affordance_errors.append(batch_affordance_errors) + test_gt_end_time = time.time() + translation_distance_errors = np.concatenate(translation_distance_errors, axis=0) + translation_along_errors = np.concatenate(translation_along_errors, axis=0) + translation_perp_errors = np.concatenate(translation_perp_errors, axis=0) + translation_plane_errors = np.concatenate(translation_plane_errors, axis=0) + translation_line_errors = np.concatenate(translation_line_errors, axis=0) + rotation_errors = np.concatenate(rotation_errors, axis=0) + affordance_errors = np.concatenate(affordance_errors, axis=0) + + results_dict = { + 'translation_distance_errors': translation_distance_errors, + 'translation_along_errors': translation_along_errors, + 'translation_perp_errors': translation_perp_errors, + 'translation_plane_errors': translation_plane_errors, + 'translation_line_errors': translation_line_errors, + 'translation_outliers_num': len(translation_outliers), + 'rotation_errors': rotation_errors, + 'rotation_outliers_num': len(rotation_outliers), + 'affordance_errors': affordance_errors, + 'affordance_outliers_num': len(affordance_outliers) + } + log_metrics(results_dict, logger, output_dir, tb_writer=None) + logger.info(f"Time: {test_gt_end_time - test_gt_start_time}") + logger.info("GT Tested.") + + +if __name__ == '__main__': + test_gt() diff --git a/test_real.py b/test_real.py new file mode 100644 index 0000000..6b5e3a6 --- /dev/null +++ b/test_real.py @@ -0,0 +1,356 @@ +from typing import Dict, Union +import hydra +from omegaconf import DictConfig, OmegaConf +import logging +import os +import time +import tqdm +from itertools import combinations +import numpy as np +import torch +import torch.nn as nn +from sklearn.cluster import KMeans + +from datasets.point_tuple_dataset import ArticulationDataset +from models.roartnet import create_shot_encoder, create_encoder +from inference import voting_translation, voting_rotation +from utilities.metrics_utils import calc_translation_error, calc_translation_error_batch, calc_direction_error, calc_direction_error_batch, log_metrics +from utilities.vis_utils import visualize, visualize_translation_voting, visualize_rotation_voting, visualize_confidence_voting +from utilities.env_utils import setup_seed +from utilities.constants import seed, light_blue_color, red_color, dark_red_color, 
dark_green_color, yellow_color + + +def test_fn(test_dataloader:torch.utils.data.DataLoader, has_rgb:bool, shot_encoder:nn.Module, encoder:nn.Module, + resolution:float, voting_num:int, rot_bin_num:int, angle_tol:float, + translation2pc:bool, multi_candidate:bool, candidate_threshold:float, rotation_cluster:bool, + rotation_multi_neighbor:bool, neighbor_threshold:float, + bmm_size:int, test_num:int, device:int, vis:bool=False) -> Dict[str, Union[np.ndarray, int]]: + if rotation_cluster: + kmeans = KMeans(n_clusters=2, init='k-means++', n_init='auto') + else: + kmeans = None + rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi)) + + tested_num = 0 + with torch.no_grad(): + names = [] + translation_distance_errors = [] + translation_along_errors = [] + translation_perp_errors = [] + translation_plane_errors = [] + translation_line_errors = [] + translation_outliers = [] + rotation_errors = [] + rotation_outliers = [] + affordance_errors = [] + affordance_outliers = [] + for batch_data in tqdm.tqdm(test_dataloader): + if tested_num >= test_num: + break + if has_rgb: + pcs, pc_normals, pc_shots, pc_colors, joint_translations, joint_rotations, affordable_positions, joint_types, point_idxs_all, batch_names = batch_data + pcs, pc_normals, pc_shots, pc_colors, point_idxs_all = \ + pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), pc_colors.cuda(device), point_idxs_all.cuda(device) + else: + pcs, pc_normals, pc_shots, joint_translations, joint_rotations, affordable_positions, joint_types, point_idxs_all, batch_names = batch_data + pcs, pc_normals, pc_shots, point_idxs_all = \ + pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), point_idxs_all.cuda(device) + # (B, N, 3), (B, N, 3), (B, N, 352)(, (B, N, 3)), (B, J, 3), (B, J, 3), (B, J, 3), (B, J), (B, N_t, 2 + N_m), (B,) + B = pcs.shape[0] + N = pcs.shape[1] + J = joint_translations.shape[1] + N_t = point_idxs_all.shape[1] + tested_num += B + + # shot encoder for every point + shot_feat = shot_encoder(pc_shots) # (B, N, N_s) + + # encoder for sampled point tuples + # shot_inputs = torch.cat([shot_feat[point_idxs_all[:, i]] for i in range(0, point_idxs_all.shape[-1])], -1) # (sample_points, feature_dim * (2 + num_more)) + # normal_inputs = torch.cat([torch.max(torch.sum(normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True), + # torch.sum(-normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True)) + # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, (2+num_more \choose 2)) + # coord_inputs = torch.cat([pc[point_idxs_all[:, i]] - pc[point_idxs_all[:, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, 3 * (2+num_more \choose 2)) + # shot_inputs = [] + # normal_inputs = [] + # coord_inputs = [] + # for b in range(pcs.shape[0]): + # shot_inputs.append(torch.cat([shot_feat[b][point_idxs_all[b, :, i]] for i in range(0, point_idxs_all.shape[-1])], dim=-1)) # (sample_points, feature_dim * (2 + num_more)) + # normal_inputs.append(torch.cat([torch.max(torch.sum(normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True), + # torch.sum(-normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True)) + # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, (2+num_more \choose 2)) + # coord_inputs.append(torch.cat([pcs[b][point_idxs_all[b, :, i]] - pcs[b][point_idxs_all[b, 
:, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, 3 * (2+num_more \choose 2)) + # shot_inputs = torch.stack(shot_inputs, dim=0) # (B, sample_points, feature_dim * (2 + num_more)) + # normal_inputs = torch.stack(normal_inputs, dim=0) # (B, sample_points, (2+num_more \choose 2)) + # coord_inputs = torch.stack(coord_inputs, dim=0) # (B, sample_points, 3 * (2+num_more \choose 2)) + shot_inputs = torch.cat([ + torch.gather(shot_feat, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, shot_feat.shape[-1]))) + for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, N_s * (2 + N_m)) + normal_inputs = torch.cat([torch.max( + torch.sum(torch.gather(pc_normals, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pc_normals.shape[-1]))), + dim=-1, keepdim=True), + torch.sum(-torch.gather(pc_normals, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pc_normals.shape[-1]))), + dim=-1, keepdim=True)) + for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, (2+N_m \choose 2)) + coord_inputs = torch.cat([ + torch.gather(pcs, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pcs.shape[-1]))) - + torch.gather(pcs, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pcs.shape[-1]))) + for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, 3 * (2+N_m \choose 2)) + if has_rgb: + rgb_inputs = torch.cat([ + torch.gather(pc_colors, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_colors.shape[-1]))) + for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, 3 * (2 + N_m)) + inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1) + else: + inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1) + preds = encoder(inputs) # (B, N_t, (2 + N_r + 2 + 1) * J) + + # voting + batch_pred_translations, batch_pred_rotations, batch_pred_affordances = [], [], [] + pcs_numpy = pcs.cpu().numpy().astype(np.float32) # (B, N, 3) + pc_normals_numpy = pc_normals.cpu().numpy().astype(np.float32) # (B, N, 3) + joint_translations_numpy = joint_translations.numpy().astype(np.float32) # (B, J, 3) + joint_rotations_numpy = joint_rotations.numpy().astype(np.float32) # (B, J, 3) + affordable_positions_numpy = affordable_positions.numpy().astype(np.float32) # (B, J, 3) + point_idxs_numpy = point_idxs_all[:, :, :2].cpu().numpy().astype(np.int32) # (B, N_t, 2) + preds_numpy = preds.cpu().numpy().astype(np.float32) # (B, N_t, (2 + N_r + 2 + 1) * J) + for b in range(B): + pc = pcs_numpy[b] # (N, 3) + pc_normal = pc_normals_numpy[b] # (N, 3) + joint_translation = joint_translations_numpy[b] # (J, 3) + joint_rotation = joint_rotations_numpy[b] # (J, 3) + affordable_position = affordable_positions_numpy[b] # (J, 3) + point_idx = point_idxs_numpy[b] # (N_t, 2) + pred = preds_numpy[b] # (N_t, (2 + N_r + 2 + 1) * J) + pred_tensor = torch.from_numpy(pred) + + pred_translations, pred_rotations, pred_affordances = [], [], [] + for j in range(J): + # conf selection + pred_conf = torch.sigmoid(pred_tensor[:, -1*J+j]) # (N_t,) + not_selected_indices = pred_conf < 0.5 + pred_conf[not_selected_indices] = 0 + # pred_conf[pred_conf > 0] = 1 + pred_conf = pred_conf.numpy() + if vis: + visualize_confidence_voting(pred_conf, pc, point_idx, + whether_frame=True, whether_bbox=True, 
window_name='conf_voting') + import pdb; pdb.set_trace() + + # translation voting + pred_tr = pred[:, 2*j:2*(j+1)] # (N_t, 2) + pred_translation, grid_obj, corners = voting_translation(pc, pred_tr, point_idx, pred_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_translations.append(pred_translation) + + # rotation voting + pred_rot = pred_tensor[:, (2*J+rot_bin_num*j):(2*J+rot_bin_num*(j+1))] # (N_t, rot_bin_num) + pred_rot = torch.softmax(pred_rot, dim=-1) + pred_rot = torch.multinomial(pred_rot, 1).float()[:, 0] # (N_t,) + pred_rot = pred_rot / (rot_bin_num - 1) * np.pi + pred_rot = pred_rot.numpy() + pred_direction, sphere_pts, counts = voting_rotation(pc, pred_rot, point_idx, pred_conf, + rot_candidate_num, angle_tol, voting_num, bmm_size, device, + multi_candidate, candidate_threshold, rotation_cluster, kmeans, + rotation_multi_neighbor, neighbor_threshold) + pred_rotations.append(pred_direction) + + # affordance voting + pred_afford = pred[:, (2*J+rot_bin_num*J+2*j):(2*J+rot_bin_num*J+2*(j+1))] # (N_t, 2) + pred_affordance, agrid_obj, acorners = voting_translation(pc, pred_afford, point_idx, pred_conf, + resolution, voting_num, device, + translation2pc, multi_candidate, candidate_threshold) + pred_affordances.append(pred_affordance) + + translation_errors = calc_translation_error(pred_translation, joint_translation[j], pred_direction, joint_rotation[j]) + if sum(translation_errors) > 20: + translation_outliers.append(translation_errors) + if vis and sum(translation_errors) > 20: + print(f"{translation_errors = }") + indices = np.indices(grid_obj.shape) + indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(grid_obj.shape)) + votes_list = grid_obj.reshape(-1) + grid_pc = corners[0] + indices_list * resolution + visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color, + gt_translation=joint_translation[j], gt_color=dark_green_color, + pred_translation=pred_translation, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='tr_voting') + import pdb; pdb.set_trace() + direction_error = calc_direction_error(pred_direction, joint_rotation[j]) + if direction_error > 5: + rotation_outliers.append(direction_error) + if vis and direction_error > 5: + print(f"{direction_error = }") + visualize_rotation_voting(sphere_pts, counts, pc, pc_color=light_blue_color, + gt_rotation=joint_rotation[j], gt_color=dark_green_color, + pred_rotation=pred_direction, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='rot_voting') + import pdb; pdb.set_trace() + affordance_error, _, _, _, _ = calc_translation_error(pred_affordance, affordable_position[j], None, None) + if affordance_error > 5: + affordance_outliers.append(affordance_error) + if vis and affordance_error > 5: + print(f"{affordance_error = }") + indices = np.indices(agrid_obj.shape) + indices_list = np.transpose(indices, (1, 2, 3, 0)).reshape(-1, len(agrid_obj.shape)) + votes_list = agrid_obj.reshape(-1) + grid_pc = acorners[0] + indices_list * resolution + visualize_translation_voting(grid_pc, votes_list, pc, pc_color=light_blue_color, + gt_translation=affordable_position[j], gt_color=dark_green_color, + pred_translation=pred_affordance, pred_color=yellow_color, + show_threshold=candidate_threshold, whether_frame=True, whether_bbox=True, window_name='afford_voting') + import pdb; pdb.set_trace() + if vis: + visualize(pc, 
pc_color=light_blue_color, pc_normal=pc_normal, + joint_translations=np.array(pred_translations), joint_rotations=np.array(pred_rotations), affordable_positions=np.array(pred_affordances), + joint_axis_colors=red_color, joint_point_colors=dark_red_color, affordable_position_colors=dark_green_color, + whether_frame=True, whether_bbox=True, window_name='pred') + import pdb; pdb.set_trace() + batch_pred_translations.append(pred_translations) + batch_pred_rotations.append(pred_rotations) + batch_pred_affordances.append(pred_affordances) + batch_pred_translations = np.array(batch_pred_translations).astype(np.float32) # (B, J, 3) + batch_pred_rotations = np.array(batch_pred_rotations).astype(np.float32) # (B, J, 3) + batch_pred_affordances = np.array(batch_pred_affordances).astype(np.float32) # (B, J, 3) + batch_gt_translations = joint_translations.numpy().astype(np.float32) # (B, J, 3) + batch_gt_rotations = joint_rotations.numpy().astype(np.float32) # (B, J, 3) + batch_gt_affordances = affordable_positions.numpy().astype(np.float32) # (B, J, 3) + batch_translation_errors = calc_translation_error_batch(batch_pred_translations, batch_gt_translations, batch_pred_rotations, batch_gt_rotations) # (B, J) + batch_rotation_errors = calc_direction_error_batch(batch_pred_rotations, batch_gt_rotations) # (B, J) + batch_affordance_errors, _, _, _, _ = calc_translation_error_batch(batch_pred_affordances, batch_gt_affordances, None, None) # (B, J) + translation_distance_errors.append(batch_translation_errors[0]) + translation_along_errors.append(batch_translation_errors[1]) + translation_perp_errors.append(batch_translation_errors[2]) + translation_plane_errors.append(batch_translation_errors[3]) + translation_line_errors.append(batch_translation_errors[4]) + rotation_errors.append(batch_rotation_errors) + affordance_errors.append(batch_affordance_errors) + names.extend(batch_names) + translation_distance_errors = np.concatenate(translation_distance_errors, axis=0) # (tested_num, J) + translation_along_errors = np.concatenate(translation_along_errors, axis=0) # (tested_num, J) + translation_perp_errors = np.concatenate(translation_perp_errors, axis=0) # (tested_num, J) + translation_plane_errors = np.concatenate(translation_plane_errors, axis=0) # (tested_num, J) + translation_line_errors = np.concatenate(translation_line_errors, axis=0) # (tested_num, J) + rotation_errors = np.concatenate(rotation_errors, axis=0) # (tested_num, J) + affordance_errors = np.concatenate(affordance_errors, axis=0) # (tested_num, J) + + return { + 'names': names, + 'translation_distance_errors': translation_distance_errors, + 'translation_along_errors': translation_along_errors, + 'translation_perp_errors': translation_perp_errors, + 'translation_plane_errors': translation_plane_errors, + 'translation_line_errors': translation_line_errors, + 'translation_outliers_num': len(translation_outliers), + 'rotation_errors': rotation_errors, + 'rotation_outliers_num': len(rotation_outliers), + 'affordance_errors': affordance_errors, + 'affordance_outliers_num': len(affordance_outliers) + } + + +@hydra.main(config_path='./configs', config_name='test_real_config', version_base='1.2') +def test_real(cfg:DictConfig) -> None: + logger = logging.getLogger('test_real') + hydra_cfg = hydra.core.hydra_config.HydraConfig.get() + output_dir = hydra_cfg['runtime']['output_dir'] + setup_seed(seed=cfg.testing.seed) + trained_path = cfg.trained.path + trained_cfg = OmegaConf.load(f"{trained_path}/.hydra/config.yaml") + # merge trained_cfg into cfg, 
cfg has higher priority + cfg = OmegaConf.merge(trained_cfg, cfg) + print(OmegaConf.to_yaml(cfg)) + + # prepare dataset + logger.info("Preparing dataset...") + device = cfg.testing.device + path = cfg.dataset.path + instances = cfg.dataset.instances + joint_num = cfg.dataset.joint_num + resolution = cfg.dataset.resolution + receptive_field = cfg.dataset.receptive_field + rgb = cfg.dataset.rgb + denoise = cfg.dataset.denoise + normalize = cfg.dataset.normalize + sample_points_num = cfg.dataset.sample_points_num + sample_tuples_num = cfg.algorithm.sampling.sample_tuples_num + tuple_more_num = cfg.algorithm.sampling.tuple_more_num + dataset = ArticulationDataset(path, instances, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + rgb, denoise, normalize, debug=False, vis=False, is_train=False) + + batch_size = cfg.testing.batch_size + num_workers = cfg.testing.num_workers + dataloader = torch.utils.data.DataLoader(dataset, pin_memory=True, batch_size=batch_size, shuffle=False, num_workers=num_workers) + logger.info("Prepared dataset.") + + # prepare model + logger.info("Preparing model...") + shot_hidden_dims = cfg.algorithm.shot_encoder.hidden_dims + shot_feature_dim = cfg.algorithm.shot_encoder.feature_dim + shot_bn = cfg.algorithm.shot_encoder.bn + shot_ln = cfg.algorithm.shot_encoder.ln + shot_droput = cfg.algorithm.shot_encoder.dropout + shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim, + shot_bn, shot_ln, shot_droput) + shot_encoder.load_state_dict(torch.load(f'{os.path.join(trained_path, "weights")}/shot_encoder_latest.pth', map_location=torch.device(device))) + shot_encoder = shot_encoder.cuda(device) + overall_hidden_dims = cfg.algorithm.encoder.hidden_dims + rot_bin_num = cfg.algorithm.voting.rot_bin_num + overall_bn = cfg.algorithm.encoder.bn + overall_ln = cfg.algorithm.encoder.ln + overall_dropout = cfg.algorithm.encoder.dropout + encoder = create_encoder(tuple_more_num, shot_feature_dim, rgb, overall_hidden_dims, rot_bin_num, joint_num, + overall_bn, overall_ln, overall_dropout) + encoder.load_state_dict(torch.load(f'{os.path.join(trained_path, "weights")}/encoder_latest.pth', map_location=torch.device(device))) + encoder = encoder.cuda(device) + logger.info("Prepared model.") + + # testing + voting_num = cfg.algorithm.voting.voting_num + angle_tol = cfg.algorithm.voting.angle_tol + translation2pc = cfg.algorithm.voting.translation2pc + multi_candidate = cfg.algorithm.voting.multi_candidate + candidate_threshold = cfg.algorithm.voting.candidate_threshold + rotation_multi_neighbor = cfg.algorithm.voting.rotation_multi_neighbor + neighbor_threshold = cfg.algorithm.voting.neighbor_threshold + rotation_cluster = cfg.algorithm.voting.rotation_cluster + bmm_size = cfg.algorithm.voting.bmm_size + logger.info("Testing...") + testing_start_time = time.time() + shot_encoder.eval() + encoder.eval() + + testing_results = test_fn(dataloader, rgb, shot_encoder, encoder, + resolution, voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, + bmm_size, len(dataset), device, vis=cfg.vis) + log_metrics(testing_results, logger, output_dir, tb_writer=None) + + testing_end_time = time.time() + logger.info("Tested.") + logger.info("Testing time: " + str(testing_end_time - testing_start_time)) + + +if __name__ == '__main__': + test_real() diff --git a/train.py b/train.py new file mode 100644 index 0000000..6da9818 --- /dev/null +++ 
b/train.py @@ -0,0 +1,378 @@ +import hydra +from omegaconf import DictConfig +import logging +import os +from itertools import combinations +import time +import tqdm +import numpy as np +import torch +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from warmup_scheduler import GradualWarmupScheduler + +from datasets.rconfmask_afford_point_tuple_dataset import ArticulationDataset +from models.roartnet import create_shot_encoder, create_encoder +from test import test_fn +from utilities.env_utils import setup_seed +from utilities.metrics_utils import AverageMeter, log_metrics +from utilities.data_utils import real2prob + + +@hydra.main(config_path='./configs', config_name='train_config', version_base='1.2') +def train(cfg:DictConfig) -> None: + logger = logging.getLogger('train') + hydra_cfg = hydra.core.hydra_config.HydraConfig.get() + output_dir = hydra_cfg['runtime']['output_dir'] + setup_seed(seed=cfg.training.seed) + + # prepare dataset + logger.info("Preparing dataset...") + device = cfg.training.device + training_path = cfg.dataset.train_path + training_categories = cfg.dataset.train_categories + joint_num = cfg.dataset.joint_num + resolution = cfg.dataset.resolution + receptive_field = cfg.dataset.receptive_field + rgb = cfg.dataset.rgb + denoise = cfg.dataset.denoise + normalize = cfg.dataset.normalize + sample_points_num = cfg.dataset.sample_points_num + sample_tuples_num = cfg.algorithm.sampling.sample_tuples_num + tuple_more_num = cfg.algorithm.sampling.tuple_more_num + training_dataset = ArticulationDataset(training_path, training_categories, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + rgb, denoise, normalize, debug=False, vis=False, is_train=True) + + batch_size = cfg.training.batch_size + num_workers = cfg.training.num_workers + training_dataloader = torch.utils.data.DataLoader(training_dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers) + + testing_training_dataset = ArticulationDataset(training_path, training_categories, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + rgb, denoise, normalize, debug=False, vis=False, is_train=False) + testing_training_dataloader = torch.utils.data.DataLoader(testing_training_dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers) + testing_path = cfg.dataset.test_path + testing_categories = cfg.dataset.test_categories + testing_testing_dataset = ArticulationDataset(testing_path, testing_categories, joint_num, resolution, receptive_field, + sample_points_num, sample_tuples_num, tuple_more_num, + rgb, denoise, normalize, debug=False, vis=False, is_train=False) + testing_testing_dataloader = torch.utils.data.DataLoader(testing_testing_dataset, pin_memory=True, batch_size=batch_size, shuffle=True, num_workers=num_workers) + logger.info("Prepared dataset.") + + # prepare model + logger.info("Preparing model...") + shot_hidden_dims = cfg.algorithm.shot_encoder.hidden_dims + shot_feature_dim = cfg.algorithm.shot_encoder.feature_dim + shot_bn = cfg.algorithm.shot_encoder.bn + shot_ln = cfg.algorithm.shot_encoder.ln + shot_droput = cfg.algorithm.shot_encoder.dropout + shot_encoder = create_shot_encoder(shot_hidden_dims, shot_feature_dim, + shot_bn, shot_ln, shot_droput) + shot_encoder = shot_encoder.cuda(device) + overall_hidden_dims = cfg.algorithm.encoder.hidden_dims + rot_bin_num = cfg.algorithm.voting.rot_bin_num + 
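(For reference while reading the slicing in `test.py`/`test_real.py` and the loss terms below: with these settings `create_encoder` emits, per sampled point tuple, a vector of length `(2 + rot_bin_num + 2 + 1) * J`. The following sketch unpacks one prediction row for joint `j`, mirroring the test-time indexing; sizes and variable names here are illustrative.)

```python
import numpy as np

N_t, J, rot_bin_num = 128, 1, 36   # illustrative sizes
pred = np.zeros((N_t, (2 + rot_bin_num + 2 + 1) * J), dtype=np.float32)
j = 0                              # joint index

tr     = pred[:, 2*j:2*(j+1)]                                                # (N_t, 2) offsets fed to voting_translation
rot    = pred[:, 2*J + rot_bin_num*j : 2*J + rot_bin_num*(j+1)]              # (N_t, rot_bin_num) angle bins fed to voting_rotation
afford = pred[:, 2*J + rot_bin_num*J + 2*j : 2*J + rot_bin_num*J + 2*(j+1)]  # (N_t, 2) offsets for the affordable position
conf   = pred[:, -J + j]                                                     # (N_t,) confidence logits; sigmoid >= 0.5 kept at test time
```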
overall_bn = cfg.algorithm.encoder.bn + overall_ln = cfg.algorithm.encoder.ln + overall_dropout = cfg.algorithm.encoder.dropout + encoder = create_encoder(tuple_more_num, shot_feature_dim, rgb, overall_hidden_dims, rot_bin_num, joint_num, + overall_bn, overall_ln, overall_dropout) + encoder = encoder.cuda(device) + logger.info("Prepared model.") + + # optimize + logger.info("Optimizing...") + training_start_time = time.time() + lr = cfg.training.lr + weight_decay = cfg.training.weight_decay + epoch_num = cfg.training.epoch_num + lambda_rot = cfg.training.lambda_rot + lambda_afford = cfg.training.lambda_afford + lambda_conf = cfg.training.lambda_conf + voting_num = cfg.algorithm.voting.voting_num + angle_tol = cfg.algorithm.voting.angle_tol + translation2pc = cfg.algorithm.voting.translation2pc + multi_candidate = cfg.algorithm.voting.multi_candidate + candidate_threshold = cfg.algorithm.voting.candidate_threshold + rotation_multi_neighbor = cfg.algorithm.voting.rotation_multi_neighbor + neighbor_threshold = cfg.algorithm.voting.neighbor_threshold + rotation_cluster = cfg.algorithm.voting.rotation_cluster + bmm_size = cfg.algorithm.voting.bmm_size + opt = optim.Adam([*encoder.parameters(), *shot_encoder.parameters()], lr=lr, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epoch_num, eta_min=lr/100.0) + scheduler_warmup = GradualWarmupScheduler(opt, multiplier=1, total_epoch=epoch_num//20, after_scheduler=scheduler) + tb_writer = SummaryWriter(log_dir=os.path.join(output_dir, 'tb')) + iteration = 0 + for epoch in range(epoch_num): + if epoch == 0: + opt.zero_grad() + opt.step() + scheduler_warmup.step() + + loss_meter = AverageMeter() + loss_tr_meter = AverageMeter() + loss_rot_meter = AverageMeter() + loss_afford_meter = AverageMeter() + loss_conf_meter = AverageMeter() + + # train + shot_encoder.train() + encoder.train() + logger.info("epoch: " + str(epoch) + " lr: " + str(scheduler_warmup.get_last_lr()[0])) + tb_writer.add_scalar('lr', scheduler_warmup.get_last_lr()[0], epoch) + with tqdm.tqdm(training_dataloader) as t: + data_num = 0 + data_loader_start_time = time.time() + for batch_data in t: + if rgb: + pcs, pc_normals, pc_shots, pc_colors, target_trs, target_rots, target_affords, target_confs, point_idxs_all = batch_data + pcs, pc_normals, pc_shots, pc_colors, target_trs, target_rots, target_affords, target_confs, point_idxs_all = \ + pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), pc_colors.cuda(device), target_trs.cuda(device), target_rots.cuda(device), target_affords.cuda(device), target_confs.cuda(device), point_idxs_all.cuda(device) + else: + pcs, pc_normals, pc_shots, target_trs, target_rots, target_affords, target_confs, point_idxs_all = batch_data + pcs, pc_normals, pc_shots, target_trs, target_rots, target_affords, target_confs, point_idxs_all = \ + pcs.cuda(device), pc_normals.cuda(device), pc_shots.cuda(device), target_trs.cuda(device), target_rots.cuda(device), target_affords.cuda(device), target_confs.cuda(device), point_idxs_all.cuda(device) + # (B, N, 3), (B, N, 3), (B, N, 352)(, (B, N, 3)), (B, J, N_t, 2), (B, J, N_t), (B, J, N_t, 2), (B, J, N_t), (B, N_t, 2 + N_m) + B = pcs.shape[0] + N = pcs.shape[1] + J = target_trs.shape[1] + N_t = target_trs.shape[2] + data_num += B + + opt.zero_grad() + dataloader_end_time = time.time() + if cfg.debug: + logger.warning("Data loader time: " + str(dataloader_end_time - data_loader_start_time)) + + forward_start_time = time.time() + # shot encoder for every point + 
shot_feat = shot_encoder(pc_shots) # (B, N, N_s) + + # encoder for sampled point tuples + # shot_inputs = torch.cat([shot_feat[point_idxs_all[:, i]] for i in range(0, point_idxs_all.shape[-1])], -1) # (sample_points, feature_dim * (2 + num_more)) + # normal_inputs = torch.cat([torch.max(torch.sum(normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True), + # torch.sum(-normal[point_idxs_all[:, i]] * normal[point_idxs_all[:, j]], dim=-1, keepdim=True)) + # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, (2+num_more \choose 2)) + # coord_inputs = torch.cat([pc[point_idxs_all[:, i]] - pc[point_idxs_all[:, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], -1) # (sample_points, 3 * (2+num_more \choose 2)) + # shot_inputs = [] + # normal_inputs = [] + # coord_inputs = [] + # for b in range(pcs.shape[0]): + # shot_inputs.append(torch.cat([shot_feat[b][point_idxs_all[b, :, i]] for i in range(0, point_idxs_all.shape[-1])], dim=-1)) # (sample_points, feature_dim * (2 + num_more)) + # normal_inputs.append(torch.cat([torch.max(torch.sum(normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True), + # torch.sum(-normals[b][point_idxs_all[b, :, i]] * normals[b][point_idxs_all[b, :, j]], dim=-1, keepdim=True)) + # for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, (2+num_more \choose 2)) + # coord_inputs.append(torch.cat([pcs[b][point_idxs_all[b, :, i]] - pcs[b][point_idxs_all[b, :, j]] for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1)) # (sample_points, 3 * (2+num_more \choose 2)) + # shot_inputs = torch.stack(shot_inputs, dim=0) # (B, sample_points, feature_dim * (2 + num_more)) + # normal_inputs = torch.stack(normal_inputs, dim=0) # (B, sample_points, (2+num_more \choose 2)) + # coord_inputs = torch.stack(coord_inputs, dim=0) # (B, sample_points, 3 * (2+num_more \choose 2)) + shot_inputs = torch.cat([ + torch.gather(shot_feat, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, shot_feat.shape[-1]))) + for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, N_s * (2 + N_m)) + normal_inputs = torch.cat([torch.max( + torch.sum(torch.gather(pc_normals, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pc_normals.shape[-1]))), + dim=-1, keepdim=True), + torch.sum(-torch.gather(pc_normals, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_normals.shape[-1]))) * + torch.gather(pc_normals, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pc_normals.shape[-1]))), + dim=-1, keepdim=True)) + for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, (2+N_m \choose 2)) + coord_inputs = torch.cat([ + torch.gather(pcs, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pcs.shape[-1]))) - + torch.gather(pcs, 1, + point_idxs_all[:, :, j:j+1].expand( + (B, N_t, pcs.shape[-1]))) + for (i, j) in combinations(np.arange(point_idxs_all.shape[-1]), 2)], dim=-1) # (B, N_t, 3 * (2+N_m \choose 2)) + if rgb: + rgb_inputs = torch.cat([ + torch.gather(pc_colors, 1, + point_idxs_all[:, :, i:i+1].expand( + (B, N_t, pc_colors.shape[-1]))) + for i in range(point_idxs_all.shape[-1])], dim=-1) # (B, N_t, 3 * (2 + N_m)) + inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs, rgb_inputs], dim=-1) + else: + inputs = torch.cat([coord_inputs, normal_inputs, shot_inputs], dim=-1) + preds 
= encoder(inputs) # (B, N_t, (2 + N_r + 2 + 1) * J) + forward_end_time = time.time() + if cfg.debug: + logger.warning("Forward time: " + str(forward_end_time - forward_start_time)) + + backward_start_time = time.time() + loss = 0 + # regression loss for translation for topk + pred_trs = preds[:, :, 0:(2 * J)] # (B, N_t, 2*J) + pred_trs = pred_trs.reshape((B, N_t, J, 2)) # (B, N_t, J, 2) + pred_trs = pred_trs.transpose(1, 2) # (B, J, N_t, 2) + loss_tr_ = torch.mean((pred_trs - target_trs) ** 2, dim=-1) # (B, J, N_t) + loss_tr_ = loss_tr_ * target_confs + loss_tr = loss_tr_[loss_tr_ > 0] + loss_tr = torch.mean(loss_tr) + loss += loss_tr + loss_tr_meter.update(loss_tr.item()) + tb_writer.add_scalar('loss/loss_tr', loss_tr.item(), iteration) + + # classification loss for rotation for topk + pred_rots = preds[:, :, (2 * J):(-3 * J)] # (B, N_t, rot_bin_num*J) + pred_rots = pred_rots.reshape((B, N_t, J, rot_bin_num)) # (B, N_t, J, rot_bin_num) + pred_rots = pred_rots.transpose(1, 2) # (B, J, N_t, rot_bin_num) + pred_rots_ = F.log_softmax(pred_rots, dim=-1) # (B, J, N_t, rot_bin_num) + target_rots_ = real2prob(target_rots, np.pi, rot_bin_num, circular=False) # (B, J, N_t, rot_bin_num) + loss_rot_ = torch.sum(F.kl_div(pred_rots_, target_rots_, reduction='none'), dim=-1) # (B, J, N_t) + loss_rot_ = loss_rot_ * target_confs + loss_rot = loss_rot_[loss_rot_ > 0] + loss_rot = torch.mean(loss_rot) + loss_rot *= lambda_rot + loss += loss_rot + loss_rot_meter.update(loss_rot.item()) + tb_writer.add_scalar('loss/loss_rot', loss_rot.item(), iteration) + + # regression loss for affordance for topk + pred_affords = preds[:, :, (-3 * J):-J] # (B, N_t, 2*J) + pred_affords = pred_affords.reshape((B, N_t, J, 2)) # (B, N_t, J, 2) + pred_affords = pred_affords.transpose(1, 2) # (B, J, N_t, 2) + loss_afford_ = torch.mean((pred_affords - target_affords) ** 2, dim=-1) # (B, J, N_t) + loss_afford_ = loss_afford_ * target_confs + loss_afford = loss_afford_[loss_afford_ > 0] + loss_afford = torch.mean(loss_afford) + loss_afford *= lambda_afford + loss += loss_afford + loss_afford_meter.update(loss_afford.item()) + tb_writer.add_scalar('loss/loss_afford', loss_afford.item(), iteration) + + # classification loss for goodness + pred_confs = preds[:, :, -J:] # (B, N_t, J) + pred_confs = pred_confs.transpose(1, 2) # (B, J, N_t) + loss_conf = F.binary_cross_entropy_with_logits(pred_confs, target_confs, reduction='none') # (B, J, N_t) + loss_conf = torch.mean(loss_conf) + loss_conf *= lambda_conf + loss += loss_conf + loss_conf_meter.update(loss_conf.item()) + tb_writer.add_scalar('loss/loss_conf', loss_conf.item(), iteration) + + loss.backward(retain_graph=False) + # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.) + # torch.nn.utils.clip_grad_norm_(shot_encoder.parameters(), 1.) 
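+                # total objective: loss = loss_tr + lambda_rot * loss_rot + lambda_afford * loss_afford + lambda_conf * loss_conf
+                # (the lambda_* weights are already folded into the individual terms above); gradient clipping stays disabled by default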
+ opt.step() + backward_end_time = time.time() + if cfg.debug: + logger.warning("Backward time: " + str(backward_end_time - backward_start_time)) + + loss_meter.update(loss.item()) + tb_writer.add_scalar('loss/loss', loss.item(), iteration) + + t.set_postfix(epoch=epoch, loss=loss_meter.avg, tr=loss_tr_meter.avg, rot=loss_rot_meter.avg, afford=loss_afford_meter.avg, conf=loss_conf_meter.avg) + + iteration += 1 + data_loader_start_time = time.time() + scheduler_warmup.step() + tb_writer.add_scalar('loss/loss_tr_avg', loss_tr_meter.avg, epoch) + tb_writer.add_scalar('loss/loss_rot_avg', loss_rot_meter.avg, epoch) + tb_writer.add_scalar('loss/loss_afford_avg', loss_afford_meter.avg, epoch) + tb_writer.add_scalar('loss/loss_conf_avg', loss_conf_meter.avg, epoch) + tb_writer.add_scalar('loss/loss_avg', loss_meter.avg, epoch) + logger.info("training loss: " + str(loss_tr_meter.avg) + " + " + str(loss_rot_meter.avg) + " + " + \ + str(loss_afford_meter.avg) + " + " + str(loss_conf_meter.avg) + " = " + str(loss_meter.avg) + ", data num: " + str(data_num)) + + # save model + if epoch % (epoch_num // 10) == 0: + os.makedirs(os.path.join(output_dir, 'weights'), exist_ok=True) + torch.save(encoder.state_dict(), os.path.join(output_dir, 'weights', 'encoder_latest.pth')) + torch.save(shot_encoder.state_dict(), os.path.join(output_dir, 'weights', 'shot_encoder_latest.pth')) + + # validation + if cfg.training.val_training and epoch % (epoch_num // 10) == 0: + logger.info("Validating training...") + validating_training_start_time = time.time() + shot_encoder.eval() + encoder.eval() + + validating_training_num = cfg.training.val_training_num if cfg.training.val_training_num > 0 else len(testing_training_dataset) + validating_training_results = test_fn(testing_training_dataloader, rgb, shot_encoder, encoder, + resolution, voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, + bmm_size, validating_training_num, device, vis=False) + log_metrics(validating_training_results, logger, output_dir, tb_writer, epoch, 'training') + + validating_training_end_time = time.time() + logger.info("Validated training.") + logger.info("Validating training time: " + str(validating_training_end_time - validating_training_start_time)) + + if cfg.training.val_testing and epoch % (epoch_num // 10) == 0: + logger.info("Validating testing...") + validating_testing_start_time = time.time() + shot_encoder.eval() + encoder.eval() + + validating_testing_num = cfg.training.val_testing_num if cfg.training.val_testing_num > 0 else len(testing_testing_dataset) + validating_testing_results = test_fn(testing_testing_dataloader, rgb, shot_encoder, encoder, + resolution, voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, + bmm_size, validating_testing_num, device, vis=False) + log_metrics(validating_testing_results, logger, output_dir, tb_writer, epoch, 'testing') + + validating_testing_end_time = time.time() + logger.info("Validated testing.") + logger.info("Validating testing time: " + str(validating_testing_end_time - validating_testing_start_time)) + + training_end_time = time.time() + logger.info("Optimized.") + logger.info("Training time: " + str(training_end_time - training_start_time)) + + # test + if cfg.training.test_train: + logger.info("Testing training...") + testing_training_start_time = time.time() + shot_encoder.eval() + 
encoder.eval() + + testing_training_results = test_fn(testing_training_dataloader, rgb, shot_encoder, encoder, + resolution, voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, + bmm_size, len(testing_training_dataset), device, vis=False) + log_metrics(testing_training_results, logger, output_dir, tb_writer, epoch_num, 'training') + + testing_training_end_time = time.time() + logger.info("Tested training.") + logger.info("Testing training time: " + str(testing_training_end_time - testing_training_start_time)) + + if cfg.training.test_test: + logger.info("Testing testing...") + testing_testing_start_time = time.time() + shot_encoder.eval() + encoder.eval() + + testing_testing_results = test_fn(testing_testing_dataloader, rgb, shot_encoder, encoder, + resolution, voting_num, rot_bin_num, angle_tol, + translation2pc, multi_candidate, candidate_threshold, rotation_cluster, + rotation_multi_neighbor, neighbor_threshold, + bmm_size, len(testing_testing_dataset), device, vis=False) + log_metrics(testing_testing_results, logger, output_dir, tb_writer, epoch_num, 'testing') + + testing_testing_end_time = time.time() + logger.info("Tested testing.") + logger.info("Testing testing time: " + str(testing_testing_end_time - testing_testing_start_time)) + + # save model + os.makedirs(os.path.join(output_dir, 'weights'), exist_ok=True) + torch.save(encoder.state_dict(), os.path.join(output_dir, 'weights', 'encoder_latest.pth')) + torch.save(shot_encoder.state_dict(), os.path.join(output_dir, 'weights', 'shot_encoder_latest.pth')) + + +if __name__ == '__main__': + train() diff --git a/utilities/__init__.py b/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utilities/constants.py b/utilities/constants.py new file mode 100644 index 0000000..f96637a --- /dev/null +++ b/utilities/constants.py @@ -0,0 +1,15 @@ +import numpy as np + +EPS = 1e-3 + +seed = 42 + +light_blue_color = np.array([126/255, 208/255, 248/255]) +dark_green_color = np.array([42/255, 157/255, 142/255]) +gray_color = np.array([204/255, 204/255, 204/255]) +dark_gray_color = np.array([179/255, 179/255, 179/255]) +red_color = np.array([232/255, 17/255, 35/255]) +dark_red_color = np.array([111/255, 15/255, 21/255]) +yellow_color = np.array([255/255, 220/255, 126/255]) + +max_grasp_width = 0.08 diff --git a/utilities/data_utils.py b/utilities/data_utils.py new file mode 100644 index 0000000..230d256 --- /dev/null +++ b/utilities/data_utils.py @@ -0,0 +1,553 @@ +from typing import List, Tuple, Union, Dict, Optional +import math +import numpy as np +from scipy.spatial.transform import Rotation as srot +from scipy.optimize import least_squares, linear_sum_assignment +import torch +import xml.etree.ElementTree as ET + +from .metrics_utils import calc_translation_error_batch, calc_direction_error_batch + + +def read_joints_from_urdf_file(urdf_file): + tree_urdf = ET.parse(urdf_file) + root_urdf = tree_urdf.getroot() + + joint_dict = {} + for joint in root_urdf.iter('joint'): + joint_name = joint.attrib['name'] + joint_type = joint.attrib['type'] + for child in joint.iter('child'): + joint_child = child.attrib['link'] + for parent in joint.iter('parent'): + joint_parent = parent.attrib['link'] + for origin in joint.iter('origin'): + if 'xyz' in origin.attrib: + joint_xyz = [float(x) for x in origin.attrib['xyz'].split()] + else: + joint_xyz = [0, 0, 0] + if 'rpy' in origin.attrib: + joint_rpy = [float(x) for x in 
origin.attrib['rpy'].split()] + else: + joint_rpy = [0, 0, 0] + if joint_type == 'prismatic' or joint_type == 'revolute' or joint_type == 'continuous': + for axis in joint.iter('axis'): + joint_axis = [float(x) for x in axis.attrib['xyz'].split()] + else: + joint_axis = None + if joint_type == 'prismatic' or joint_type == 'revolute': + for limit in joint.iter('limit'): + joint_limit = [float(limit.attrib['lower']), float(limit.attrib['upper'])] + else: + joint_limit = None + + joint_dict[joint_name] = { + 'type': joint_type, + 'parent': joint_parent, + 'child': joint_child, + 'xyz': joint_xyz, + 'rpy': joint_rpy, + 'axis': joint_axis, + 'limit': joint_limit + } + + return joint_dict + + +def pc_normalize(pc:np.ndarray, normalize_method:str) -> Tuple[np.ndarray, np.ndarray, float]: + if normalize_method == 'none': + pc_normalized = pc + center = np.array([0., 0., 0.]).astype(pc.dtype) + scale = 1. + elif normalize_method == 'mean': + center = np.mean(pc, axis=0) + pc_normalized = pc - center + scale = np.max(np.sqrt(np.sum(pc_normalized ** 2, axis=1))) + pc_normalized = pc_normalized / scale + elif normalize_method == 'bound': + center = (np.max(pc, axis=0) + np.min(pc, axis=0)) / 2 + pc_normalized = pc - center + scale = np.max(np.sqrt(np.sum(pc_normalized ** 2, axis=1))) + pc_normalized = pc_normalized / scale + elif normalize_method == 'median': + center = np.median(pc, axis=0) + pc_normalized = pc - center + scale = np.max(np.sqrt(np.sum(pc_normalized ** 2, axis=1))) + pc_normalized = pc_normalized / scale + else: + raise NotImplementedError + + return (pc_normalized, center, scale) + +def joints_normalize(joint_translations:Optional[np.ndarray], joint_rotations:Optional[np.ndarray], center:np.ndarray, scale:float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + if joint_translations is None and joint_rotations is None: + return (None, None) + + J = joint_translations.shape[0] if joint_translations is not None else joint_rotations.shape[0] + joint_translations_normalized, joint_rotations_normalized = [], [] + for j in range(J): + if joint_translations is not None: + joint_translation_normalized = (joint_translations[j] - center) / scale + joint_translations_normalized.append(joint_translation_normalized) + if joint_rotations is not None: + # joint_axis_normalized = (joint_rotations[j] - center) / scal + # joint_rotation_normalized = joint_axis_normalized - joint_translation_normalized + # joint_rotation_normalized /= np.linalg.norm(joint_rotation_normalized) + joint_rotation_normalized = joint_rotations[j].copy() + joint_rotation_normalized /= np.linalg.norm(joint_rotation_normalized) + joint_rotations_normalized.append(joint_rotation_normalized) + if joint_translations is not None: + joint_translations_normalized = np.array(joint_translations_normalized).astype(joint_translations.dtype) + else: + joint_translations_normalized = None + if joint_rotations is not None: + joint_rotations_normalized = np.array(joint_rotations_normalized).astype(joint_rotations.dtype) + else: + joint_rotations_normalized = None + return (joint_translations_normalized, joint_rotations_normalized) + +def joints_denormalize(joint_translations_normalized:Optional[np.ndarray], joint_rotations_normalized:Optional[np.ndarray], center:np.ndarray, scale:float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + if joint_translations_normalized is None and joint_rotations_normalized is None: + return (None, None) + + J = joint_translations_normalized.shape[0] if joint_translations_normalized is not None 
else joint_rotations_normalized.shape[0] + joint_translations, joint_rotations = [], [] + for j in range(J): + if joint_translations_normalized is not None: + joint_translation = joint_translations_normalized[j] * scale + center + joint_translations.append(joint_translation) + if joint_rotations_normalized is not None: + # joint_axis = (joint_translations_normalized[j] + joint_rotations_normalized[j]) * scale + center + # joint_rotation = joint_axis - joint_translation + joint_rotation = joint_rotations_normalized[j].copy() + joint_rotation /= np.linalg.norm(joint_rotation) + joint_rotations.append(joint_rotation) + if joint_translations_normalized is not None: + joint_translations = np.array(joint_translations).astype(joint_translations_normalized.dtype) + else: + joint_translations = None + if joint_rotations_normalized is not None: + joint_rotations = np.array(joint_rotations).astype(joint_rotations_normalized.dtype) + else: + joint_rotations = None + return (joint_translations, joint_rotations) + +def joint_denormalize(joint_translation_normalized:Optional[np.ndarray], joint_rotation_normalized:Optional[np.ndarray], center:np.ndarray, scale:float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + if joint_translation_normalized is not None: + joint_translation = joint_translation_normalized * scale + center + else: + joint_translation = None + if joint_rotation_normalized is not None: + # joint_axis = (joint_translation_normalized + joint_rotation_normalized) * scale + center + # joint_rotation = joint_axis - joint_translation + joint_rotation = joint_rotation_normalized.copy() + joint_rotation /= np.linalg.norm(joint_rotation) + else: + joint_rotation = None + return (joint_translation, joint_rotation) + +def transform_pc(pc_camera:np.ndarray, c2w:np.ndarray) -> np.ndarray: + # pc_camera: (N, 3), c2w: (4, 4) + pc_camera_hm = np.concatenate([pc_camera, np.ones((pc_camera.shape[0], 1), dtype=pc_camera.dtype)], axis=-1) # (N, 4) + pc_world_hm = pc_camera_hm @ c2w.T # (N, 4) + pc_world = pc_world_hm[:, :3] # (N, 3) + return pc_world + +def transform_dir(dir_camera:np.ndarray, c2w:np.ndarray) -> np.ndarray: + # dir_camera: (N, 3), c2w: (4, 4) + dir_camera_hm = np.concatenate([dir_camera, np.zeros((dir_camera.shape[0], 1), dtype=dir_camera.dtype)], axis=-1) # (N, 4) + dir_world_hm = dir_camera_hm @ c2w.T # (N, 4) + dir_world = dir_world_hm[:, :3] # (N, 3) + return dir_world + + +def generate_target_tr(pc:np.ndarray, o:np.ndarray, point_idxs:np.ndarray) -> np.ndarray: + a = pc[point_idxs[:, 0]] # (N_t, 3) + b = pc[point_idxs[:, 1]] # (N_t, 3) + pdist = a - b + pdist_unit = pdist / (np.linalg.norm(pdist, axis=-1, keepdims=True) + 1e-7) + proj_len = np.sum((a - o) * pdist_unit, -1) + oc = a - o - proj_len[..., None] * pdist_unit + dist2o = np.linalg.norm(oc, axis=-1) + target_tr = np.stack([proj_len, dist2o], -1) + return target_tr.astype(np.float32).reshape((-1, 2)) + +def generate_target_rot(pc:np.ndarray, axis:np.ndarray, point_idxs:np.ndarray) -> np.ndarray: + a = pc[point_idxs[:, 0]] # (N_t, 3) + b = pc[point_idxs[:, 1]] # (N_t, 3) + pdist = a - b + pdist_unit = pdist / (np.linalg.norm(pdist, axis=-1, keepdims=True) + 1e-7) + cos = np.sum(pdist_unit * axis, axis=-1) + cos = np.clip(cos, -1., 1.) 
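+    # the rotation target is the angle in [0, pi] between each tuple's point-pair direction and the joint axis;
+    # the clip above guards arccos against numerical overshoot of the dot product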
+ target_rot = np.arccos(cos) + return target_rot.astype(np.float32).reshape((-1,)) + + +def farthest_point_sample(point:np.ndarray, npoint:int) -> Tuple[np.ndarray, np.ndarray]: + """ + Input: + xyz: pointcloud data, [N, D] + npoint: number of samples + Return: + point: sampled pointcloud, [npoint, D] + centroids: sampled pointcloud index + """ + N, D = point.shape + xyz = point[:,:3] + centroids = np.zeros((npoint,)) + distance = np.ones((N,)) * 1e10 + farthest = np.random.randint(0, N) + for i in range(npoint): + centroids[i] = farthest + centroid = xyz[farthest, :] + dist = np.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = np.argmax(distance, -1) + centroids = centroids.astype(np.int32) + point = point[centroids] + return (point, centroids) + + +def fibonacci_sphere(samples:int) -> List[Tuple[float, float, float]]: + points = [] + phi = math.pi * (3. - math.sqrt(5.)) # golden angle in radians + + for i in range(samples): + y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1 + radius = math.sqrt(1 - y * y) # radius at y + + theta = phi * i # golden angle increment + + x = math.cos(theta) * radius + z = math.sin(theta) * radius + + points.append((x, y, z)) + + return points + + +def real2prob(val:Union[torch.Tensor, np.ndarray], max_val:float, num_bins:int, circular:bool=False) -> Union[torch.Tensor, np.ndarray]: + is_torch = isinstance(val, torch.Tensor) + if is_torch: + res = torch.zeros((*val.shape, num_bins), dtype=val.dtype).to(val.device) + else: + res = np.zeros((*val.shape, num_bins), dtype=val.dtype) + + if not circular: + interval = max_val / (num_bins - 1) + if is_torch: + low = torch.clamp(torch.floor(val / interval).long(), max=num_bins - 2) + else: + low = np.clip(np.floor(val / interval).astype(np.int64), a_min=None, a_max=num_bins - 2) + high = low + 1 + # assert torch.all(low >= 0) and torch.all(high < num_bins) + + # huge memory + if is_torch: + res.scatter_(-1, low[..., None], torch.unsqueeze(1. - (val / interval - low), -1)) + res.scatter_(-1, high[..., None], 1. - torch.gather(res, -1, low[..., None])) + else: + np.put_along_axis(res, low[..., None], np.expand_dims(1. - (val / interval - low), -1), -1) + np.put_along_axis(res, high[..., None], 1. - np.take_along_axis(res, low[..., None], -1), -1) + # res[..., low] = 1. - (val / interval - low) + # res[..., high] = 1. - res[..., low] + # assert torch.all(0 <= res[..., low]) and torch.all(1 >= res[..., low]) + return res + else: + interval = max_val / num_bins + if is_torch: + val_new = torch.clone(val) + else: + val_new = val.copy() + val_new[val < interval / 2] += max_val + res = real2prob(val_new - interval / 2, max_val, num_bins + 1) + res[..., 0] += res[..., -1] + return res[..., :-1] + + +def pc_ncs(pc:np.ndarray, bbox_min:np.ndarray, bbox_max:np.ndarray) -> np.ndarray: + return (pc - (bbox_min + bbox_max) / 2 + 0.5 * (bbox_max - bbox_min)) / (bbox_max - bbox_min) + +def joints_ncs(joint_translations:np.ndarray, joint_rotations:Optional[np.ndarray], bbox_min:np.ndarray, bbox_max:np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]: + return ((joint_translations - (bbox_min + bbox_max) / 2 + 0.5 * (bbox_max - bbox_min)) / (bbox_max - bbox_min), joint_rotations.copy() if joint_rotations is not None else None) + + +def rotate_points_with_rotvec(points, rot_vecs): + """Rotate points by given rotation vectors. + + Rodrigues' rotation formula is used. 
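+
+    With theta = ||rot_vec|| and unit axis v = rot_vec / theta, each point p is mapped to
+    p * cos(theta) + (v x p) * sin(theta) + v * (v . p) * (1 - cos(theta)),
+    which is what the return expression below computes term by term.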
+ """ + theta = np.linalg.norm(rot_vecs, axis=1)[:, np.newaxis] + with np.errstate(invalid='ignore'): + v = rot_vecs / theta + v = np.nan_to_num(v) + dot = np.sum(points * v, axis=1)[:, np.newaxis] + cos_theta = np.cos(theta) + sin_theta = np.sin(theta) + + return cos_theta * points + sin_theta * np.cross(v, points) + dot * (1 - cos_theta) * v + + +def scale_pts(source, target): + # compute scaling factor between source: [N x 3], target: [N x 3] + pdist_s = source.reshape(source.shape[0], 1, 3) - source.reshape(1, source.shape[0], 3) + A = np.sqrt(np.sum(pdist_s**2, 2)).reshape(-1) + pdist_t = target.reshape(target.shape[0], 1, 3) - target.reshape(1, target.shape[0], 3) + b = np.sqrt(np.sum(pdist_t**2, 2)).reshape(-1) + scale = np.dot(A, b) / (np.dot(A, A)+1e-6) + return scale + +def rotate_pts(source, target): + # compute rotation between source: [N x 3], target: [N x 3] + # pre-centering + source = source - np.mean(source, 0, keepdims=True) + target = target - np.mean(target, 0, keepdims=True) + M = np.matmul(target.T, source) + U, D, Vh = np.linalg.svd(M, full_matrices=True) + d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 + if d: + D[-1] = -D[-1] + U[:, -1] = -U[:, -1] + R = np.matmul(U, Vh) + return R + + +def objective_eval_t(params, x0, y0, x1, y1, joints, isweight=True): + # params: [:3] R0, [3:] R1 + # x0: N x 3, y0: N x 3, x1: M x 3, y1: M x 3, R0: 1 x 3, R1: 1 x 3, joints: K x 3 + rotvec0 = params[:3].reshape((1,3)) + rotvec1 = params[3:].reshape((1,3)) + res0 = y0 - rotate_points_with_rotvec(x0, rotvec0) + res1 = y1 - rotate_points_with_rotvec(x1, rotvec1) + res_R = rotvec0 - rotvec1 + if isweight: + res0 /= x0.shape[0] + res1 /= x1.shape[0] + return np.concatenate((res0, res1, res_R), 0).ravel() + +def objective_eval_r(params, x0, y0, x1, y1, joints, isweight=True): + # params: [:3] R0, [3:] R1 + # x0: N x 3, y0: N x 3, x1: M x 3, y1: M x 3, R0: 1 x 3, R1: 1 x 3, joints: K x 3 + rotvec0 = params[:3].reshape((1,3)) + rotvec1 = params[3:].reshape((1,3)) + res0 = y0 - rotate_points_with_rotvec(x0, rotvec0) + res1 = y1 - rotate_points_with_rotvec(x1, rotvec1) + res_joint = rotate_points_with_rotvec(joints, rotvec0) - rotate_points_with_rotvec(joints, rotvec1) + if isweight: + res0 /= x0.shape[0] + res1 /= x1.shape[0] + res_joint /= joints.shape[0] + return np.concatenate((res0, res1, res_joint), 0).ravel() + + +def joint_transformation_estimator(dataset:Dict[str, np.ndarray], joint_type:str, best_inliers:Optional[Tuple[np.ndarray, np.ndarray]]=None) -> Optional[Dict[str, np.ndarray]]: + nsource0 = dataset['source0'].shape[0] + nsource1 = dataset['source1'].shape[0] + if nsource0 < 3 or nsource1 < 3: + return None + if best_inliers is None: + sample_idx0 = np.random.randint(nsource0, size=3) + sample_idx1 = np.random.randint(nsource1, size=3) + else: + sample_idx0 = best_inliers[0] + sample_idx1 = best_inliers[1] + + source0 = dataset['source0'][sample_idx0, :] + target0 = dataset['target0'][sample_idx0, :] + source1 = dataset['source1'][sample_idx1, :] + target1 = dataset['target1'][sample_idx1, :] + + scale0 = scale_pts(source0, target0) + scale1 = scale_pts(source1, target1) + scale0_inv = scale_pts(target0, source0) + scale1_inv = scale_pts(target1, source1) + + target0_scaled_centered = scale0_inv*target0 + target0_scaled_centered -= np.mean(target0_scaled_centered, 0, keepdims=True) + source0_centered = source0 - np.mean(source0, 0, keepdims=True) + + target1_scaled_centered = scale1_inv*target1 + target1_scaled_centered -= np.mean(target1_scaled_centered, 0, keepdims=True) 
+    source1_centered = source1 - np.mean(source1, 0, keepdims=True)
+
+    joint_points0 = np.ones_like(np.linspace(0, 1, num = np.min((source0.shape[0], source1.shape[0]))+1 )[1:].reshape((-1, 1)))*dataset['joint_direction'].reshape((1, 3))
+
+    R0 = rotate_pts(source0_centered, target0_scaled_centered)
+    R1 = rotate_pts(source1_centered, target1_scaled_centered)
+
+    rotvec0 = srot.from_matrix(R0).as_rotvec()
+    rotvec1 = srot.from_matrix(R1).as_rotvec()
+    if joint_type == 'prismatic':
+        res = least_squares(objective_eval_t, np.hstack((rotvec0, rotvec1)), verbose=0, ftol=1e-4, method='lm',
+                            args=(source0_centered, target0_scaled_centered, source1_centered, target1_scaled_centered, joint_points0, False))
+    elif joint_type == 'revolute':
+        res = least_squares(objective_eval_r, np.hstack((rotvec0, rotvec1)), verbose=0, ftol=1e-4, method='lm',
+                            args=(source0_centered, target0_scaled_centered, source1_centered, target1_scaled_centered, joint_points0, False))
+    else:
+        raise ValueError
+    R0 = srot.from_rotvec(res.x[:3]).as_matrix()
+    R1 = srot.from_rotvec(res.x[3:]).as_matrix()
+
+    translation0 = np.mean(target0.T-scale0*np.matmul(R0, source0.T), 1)
+    translation1 = np.mean(target1.T-scale1*np.matmul(R1, source1.T), 1)
+
+    if np.isnan(translation0).any() or np.isnan(translation1).any() or np.isnan(R0).any() or np.isnan(R1).any():
+        return None
+
+    jtrans = dict()
+    jtrans['rotation0'] = R0
+    jtrans['scale0'] = scale0
+    jtrans['translation0'] = translation0
+    jtrans['rotation1'] = R1
+    jtrans['scale1'] = scale1
+    jtrans['translation1'] = translation1
+    return jtrans
+
+def joint_transformation_verifier(dataset:Dict[str, np.ndarray], model:Dict[str, np.ndarray], inlier_th:float) -> Tuple[float, Tuple[np.ndarray, np.ndarray]]:
+    res0 = dataset['target0'].T - model['scale0'] * np.matmul( model['rotation0'], dataset['source0'].T ) - model['translation0'].reshape((3, 1))
+    inliers0 = np.sqrt(np.sum(res0**2, 0)) < inlier_th
+    res1 = dataset['target1'].T - model['scale1'] * np.matmul( model['rotation1'], dataset['source1'].T ) - model['translation1'].reshape((3, 1))
+    inliers1 = np.sqrt(np.sum(res1**2, 0)) < inlier_th
+    score = ( np.sum(inliers0)/res0.shape[0] + np.sum(inliers1)/res1.shape[0] ) / 2
+    return (score, (inliers0, inliers1))
+
+def ransac(dataset:Dict[str, np.ndarray], inlier_threshold:float, iteration_num:int, joint_type:str) -> Optional[Dict[str, np.ndarray]]:
+    best_model = None
+    best_score = -np.inf
+    best_inliers = None
+    for i in range(iteration_num):
+        cur_model = joint_transformation_estimator(dataset, joint_type=joint_type, best_inliers=None)
+        if cur_model is None:
+            return None
+        cur_score, cur_inliers = joint_transformation_verifier(dataset, cur_model, inlier_threshold)
+        if cur_score > best_score:
+            best_model = cur_model
+            best_inliers = cur_inliers
+            best_score = cur_score
+    best_model = joint_transformation_estimator(dataset, joint_type=joint_type, best_inliers=best_inliers)
+    return best_model
+
+
+def match_joints(proposal_joints:List[Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]],
+                 gt_joint_translations:np.ndarray, gt_joint_rotations:np.ndarray, match_metric:str,
+                 has_affordance:bool=False) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+    # TODO: affordance-related matching is not handled very well yet
+    proposal_num = len(proposal_joints) # N
+    gt_num = gt_joint_translations.shape[0] # M
+    if proposal_num == 0:
+        pred_joint_translations = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype)
+
pred_joint_rotations = np.ones_like(gt_joint_rotations, dtype=gt_joint_rotations.dtype) + pred_joint_rotations = pred_joint_rotations / np.linalg.norm(pred_joint_rotations, axis=-1, keepdims=True) + if has_affordance: + pred_affordances = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype) + return (pred_joint_translations, pred_joint_rotations, pred_affordances) + else: + return (pred_joint_translations, pred_joint_rotations) + + proposal_joint_translations = np.array([proposal_joints[i][0] for i in range(proposal_num)]) # (N, 3) + proposal_joint_rotations = np.array([proposal_joints[i][1] for i in range(proposal_num)]) # (N, 3) + if has_affordance: + proposal_affordances = np.array([proposal_joints[i][2] for i in range(proposal_num)]) # (N, 3) + else: + pass + cost_matrix = np.zeros((gt_num, proposal_num)) # (M, N) + for gt_idx in range(gt_num): + gt_joint_translation = gt_joint_translations[gt_idx].reshape((1, 3)).repeat(proposal_num, axis=0) # (N, 3) + gt_joint_rotation = gt_joint_rotations[gt_idx].reshape((1, 3)).repeat(proposal_num, axis=0) # (N, 3) + translation_errors = calc_translation_error_batch(proposal_joint_translations, gt_joint_translation, proposal_joint_rotations, gt_joint_rotation) # (N,) + direction_errors = calc_direction_error_batch(proposal_joint_rotations, gt_joint_rotation) # (N,) + if match_metric == 'tr_dist': + cost_matrix[gt_idx] = translation_errors[0] + elif match_metric == 'tr_along': + cost_matrix[gt_idx] = translation_errors[1] + elif match_metric == 'tr_perp': + cost_matrix[gt_idx] = translation_errors[2] + elif match_metric == 'tr_plane': + cost_matrix[gt_idx] = translation_errors[3] + elif match_metric == 'tr_line': + cost_matrix[gt_idx] = translation_errors[4] + elif match_metric == 'tr_mean': + cost_matrix[gt_idx] = (translation_errors[0] + translation_errors[1] + translation_errors[2] + translation_errors[3] + translation_errors[4]) / 5 + elif match_metric == 'dir': + cost_matrix[gt_idx] = direction_errors + elif match_metric == 'tr_dist_dir': + cost_matrix[gt_idx] = (translation_errors[0] + direction_errors) / 2 + elif match_metric == 'tr_line_dir': + cost_matrix[gt_idx] = (translation_errors[4] + direction_errors) / 2 + elif match_metric == 'tr_mean_dir': + cost_matrix[gt_idx] = ((translation_errors[0] + translation_errors[1] + translation_errors[2] + translation_errors[3] + translation_errors[4]) / 5 + direction_errors) / 2 + else: + raise NotImplementedError + row_ind, col_ind = linear_sum_assignment(cost_matrix) + + pred_joint_translations = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype) + pred_joint_rotations = np.ones_like(gt_joint_rotations, dtype=gt_joint_rotations.dtype) + if has_affordance: + pred_affordances = np.zeros_like(gt_joint_translations, dtype=gt_joint_translations.dtype) + else: + pass + for gt_idx, proposal_idx in zip(row_ind, col_ind): + pred_joint_translations[gt_idx] = proposal_joint_translations[proposal_idx] + pred_joint_rotations[gt_idx] = proposal_joint_rotations[proposal_idx] + if has_affordance: + pred_affordances[gt_idx] = proposal_affordances[proposal_idx] + else: + pass + pred_joint_rotations = pred_joint_rotations / np.linalg.norm(pred_joint_rotations, axis=-1, keepdims=True) + if has_affordance: + return (pred_joint_translations, pred_joint_rotations, pred_affordances) + else: + return (pred_joint_translations, pred_joint_rotations) + + +def pc_noise(pc:np.ndarray, distortion_rate:float, distortion_level:float, + outlier_rate:float, outlier_level:float) -> np.ndarray: + 
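+    # both noise types are scaled by the point cloud's axis-aligned bounding-box extent:
+    # - distortion: a random `distortion_rate` fraction of points gets Gaussian jitter with std `distortion_level * pc_scale`
+    # - outliers: a random `outlier_rate` fraction of points is replaced by uniform samples in `pc_center +/- outlier_level * pc_scale`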
num_points = pc.shape[0] + pc_center = (np.max(pc, axis=0) + np.min(pc, axis=0)) / 2 + pc_scale = np.max(pc, axis=0) - np.min(pc, axis=0) + pc_noised = pc.copy() + # distortion noise + distortion_indices = np.random.choice(num_points, int(distortion_rate * num_points), replace=False) + distortion_noise = np.random.normal(0.0, distortion_level * pc_scale, pc_noised[distortion_indices].shape) + pc_noised[distortion_indices] = pc_noised[distortion_indices] + distortion_noise + # outlier noise + outlier_indices = np.random.choice(num_points, int(outlier_rate * num_points), replace=False) + outlier_noise = np.random.uniform(pc_center - outlier_level * pc_scale, pc_center + outlier_level * pc_scale, pc_noised[outlier_indices].shape) + pc_noised[outlier_indices] = outlier_noise + # print(num_points, int(distortion_rate * num_points), int(outlier_rate * num_points)) + return pc_noised + + +if __name__ == '__main__': + # angle_tol = 1.5 + angle_tol = 0.35 + rot_candidate_num = int(4 * np.pi / (angle_tol / 180 * np.pi)) + print(rot_candidate_num) + sphere_pts = fibonacci_sphere(rot_candidate_num) # (N, 3) + sphere_pts = np.array(sphere_pts) + + # figure out the angles between neighboring 8 points on the sphere + min_angles = [] + max_angles = [] + mean_angles = [] + for i in range(len(sphere_pts)): + a = sphere_pts[i] # (3,) + cos = np.dot(a, sphere_pts.T) # (N,) + cos = np.clip(cos, -1., 1.) + theta = np.arccos(cos) # (N,) + idxs = np.argsort(theta)[:9] # (9,) + thetas = theta[idxs[1:]] / np.pi * 180 # (8,) + min_angles.append(np.min(thetas)) + max_angles.append(np.max(thetas)) + mean_angles.append(np.mean(thetas)) + + print(np.min(min_angles)) + print(np.max(min_angles)) + print(np.mean(min_angles)) + print(np.min(max_angles)) + print(np.max(max_angles)) + print(np.mean(max_angles)) + print(np.min(mean_angles)) + print(np.max(mean_angles)) + print(np.mean(mean_angles)) diff --git a/utilities/env_utils.py b/utilities/env_utils.py new file mode 100644 index 0000000..eb91e50 --- /dev/null +++ b/utilities/env_utils.py @@ -0,0 +1,13 @@ +import os +import random +import numpy as np +import torch + +def setup_seed(seed:int) -> None: + random.seed(seed) + np.random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True diff --git a/utilities/metrics_utils.py b/utilities/metrics_utils.py new file mode 100644 index 0000000..04a6742 --- /dev/null +++ b/utilities/metrics_utils.py @@ -0,0 +1,651 @@ +from typing import Tuple, List, Union, Dict, Optional +import os +import numpy as np +import torch +import torch.nn.functional as F +from logging import Logger +from torch.utils.tensorboard import SummaryWriter +import pandas as pd + + +def calc_translation_error(pred_p:np.ndarray, gt_p:np.ndarray, pred_e:Optional[np.ndarray], gt_e:Optional[np.ndarray]) -> Tuple[float, float, float, float, float]: + def calc_plane_error(pred_translation:np.ndarray, pred_direction:np.ndarray, + gt_translation:np.ndarray, gt_direction:np.ndarray) -> float: + if abs(np.dot(pred_direction, gt_direction)) < 1e-3: + # parallel to the plane + # point-to-line distance + dist = np.linalg.norm(np.cross(pred_direction, gt_translation - pred_translation)) + return dist + # gt_direction \dot (x - gt_translation) = 0 + # x = pred_translation + t * pred_direction + t = np.dot(gt_translation - pred_translation, gt_direction) / np.dot(pred_direction, gt_direction) + x = pred_translation + t * pred_direction + dist = 
np.linalg.norm(x - gt_translation)
+        return dist
+    def calc_line_error(pred_translation:np.ndarray, pred_direction:np.ndarray,
+                        gt_translation:np.ndarray, gt_direction:np.ndarray) -> float:
+        orth_vect = np.cross(gt_direction, pred_direction)
+        p = gt_translation - pred_translation
+        if np.linalg.norm(orth_vect) < 1e-3:
+            dist = np.linalg.norm(np.cross(p, gt_direction)) / np.linalg.norm(gt_direction)
+        else:
+            dist = abs(np.dot(orth_vect, p)) / np.linalg.norm(orth_vect)
+        return dist
+    distance_error = np.linalg.norm(pred_p - gt_p) * 100.0
+    if pred_e is None or gt_e is None:
+        along_error = 0.0
+        perp_error = 0.0
+        plane_error = 0.0
+        line_error = 0.0
+    else:
+        along_error = abs(np.dot(pred_p - gt_p, gt_e)) * 100.0
+        perp_error = np.sqrt(distance_error**2 - along_error**2)
+        plane_error = calc_plane_error(pred_p, pred_e, gt_p, gt_e) * 100.0
+        line_error = calc_line_error(pred_p, pred_e, gt_p, gt_e) * 100.0
+
+    return (distance_error, along_error, perp_error, plane_error, line_error)
+
+def calc_translation_error_batch(pred_ps:np.ndarray, gt_ps:np.ndarray, pred_es:Optional[np.ndarray], gt_es:Optional[np.ndarray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    def calc_plane_error_batch(pred_translations:np.ndarray, pred_directions:np.ndarray,
+                               gt_translations:np.ndarray, gt_directions:np.ndarray) -> np.ndarray:
+        dists = np.zeros(pred_translations.shape[:-1], dtype=pred_translations.dtype)
+        flag = np.abs(np.sum(pred_directions * gt_directions, axis=-1)) < 1e-3
+        not_flag = np.logical_not(flag)
+
+        dists[flag] = np.linalg.norm(np.cross(pred_directions[flag], gt_translations[flag] - pred_translations[flag]), axis=-1)
+
+        ts = np.sum((gt_translations[not_flag] - pred_translations[not_flag]) * gt_directions[not_flag], axis=-1) / np.sum(pred_directions[not_flag] * gt_directions[not_flag], axis=-1)
+        xs = pred_translations[not_flag] + ts[..., None] * pred_directions[not_flag]
+        dists[not_flag] = np.linalg.norm(xs - gt_translations[not_flag], axis=-1)
+        return dists
+    def calc_line_error_batch(pred_translations:np.ndarray, pred_directions:np.ndarray,
+                              gt_translations:np.ndarray, gt_directions:np.ndarray) -> np.ndarray:
+        dists = np.zeros(pred_translations.shape[:-1], dtype=pred_translations.dtype)
+        orth_vects = np.cross(gt_directions, pred_directions)
+        ps = gt_translations - pred_translations
+        flag = np.linalg.norm(orth_vects, axis=-1) < 1e-3
+        not_flag = np.logical_not(flag)
+
+        dists[flag] = np.linalg.norm(np.cross(ps[flag], gt_directions[flag]), axis=-1) / np.linalg.norm(gt_directions[flag], axis=-1)
+
+        dists[not_flag] = np.abs(np.sum(orth_vects[not_flag] * ps[not_flag], axis=-1)) / np.linalg.norm(orth_vects[not_flag], axis=-1)
+        return dists
+    distance_errors = np.linalg.norm(pred_ps - gt_ps, axis=-1) * 100.0
+    if pred_es is None or gt_es is None:
+        along_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+        perp_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+        plane_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+        line_errors = np.zeros(distance_errors.shape, dtype=distance_errors.dtype)
+    else:
+        along_errors = np.abs(np.sum((pred_ps - gt_ps) * gt_es, axis=-1)) * 100.0
+        perp_errors = np.sqrt(distance_errors**2 - along_errors**2)
+        plane_errors = calc_plane_error_batch(pred_ps, pred_es, gt_ps, gt_es) * 100.0
+        line_errors = calc_line_error_batch(pred_ps, pred_es, gt_ps, gt_es) * 100.0
+
+    return (distance_errors, along_errors, perp_errors, plane_errors, line_errors)
+
+def
calc_direction_error(pred_e:np.ndarray, gt_e:np.ndarray) -> float: + cos_theta = np.dot(pred_e, gt_e) + cos_theta = np.clip(cos_theta, -1., 1.) + angle_radian = np.arccos(cos_theta) * 180.0 / np.pi + return angle_radian + +def calc_direction_error_batch(pred_es:np.ndarray, gt_es:np.ndarray) -> np.ndarray: + cos_thetas = np.sum(pred_es * gt_es, axis=-1) + cos_thetas = np.clip(cos_thetas, -1., 1.) + angle_radians = np.arccos(cos_thetas) * 180.0 / np.pi + return angle_radians + + +def calculate_miou_loss(pred_seg:torch.Tensor, gt_seg_onehot:torch.Tensor) -> torch.Tensor: + # pred_seg: (B, N, C), gt_seg_onehot: (B, N, C) + dot = torch.sum(pred_seg * gt_seg_onehot, dim=-2) + denominator = torch.sum(pred_seg, dim=-2) + torch.sum(gt_seg_onehot, dim=-2) - dot + mIoU = dot / (denominator + 1e-7) # (B, C) + return torch.mean(1.0 - mIoU) + +def calculate_ncs_loss(pred_ncs:torch.Tensor, gt_ncs:torch.Tensor, gt_seg_onehot:torch.Tensor) -> torch.Tensor: + # pred_ncs: (B, N, 3 * C), gt_ncs: (B, N, 3), gt_seg_onehot: (B, N, C) + loss_ncs = 0 + ncs_splits = torch.split(pred_ncs, split_size_or_sections=3, dim=-1) # (C,), (B, N, 3) + seg_splits = torch.split(gt_seg_onehot, split_size_or_sections=1, dim=2) # (C,), (B, N, 1) + C = len(ncs_splits) + for i in range(C): + diff_l2 = torch.norm(ncs_splits[i] - gt_ncs, dim=-1) # (B, N) + loss_ncs += torch.mean(seg_splits[i][:, :, 0] * diff_l2, dim=-1) # (B,) + return torch.mean(loss_ncs) + +def calculate_heatmap_loss(pred_heatmap:torch.Tensor, gt_heatmap:torch.Tensor, mask:torch.Tensor) -> torch.Tensor: + # pred_heatmap: (B, N), gt_heatmap: (B, N), mask: (B, N) + loss_heatmap = torch.abs(pred_heatmap - gt_heatmap)[mask > 0] + loss_heatmap = torch.mean(loss_heatmap) + loss_heatmap[torch.isnan(loss_heatmap)] = 0.0 + return loss_heatmap + +def calculate_unitvec_loss(pred_unitvec:torch.Tensor, gt_unitvec:torch.Tensor, mask:torch.Tensor) -> torch.Tensor: + # pred_unitvec: (B, N, 3), gt_unitvec: (B, N, 3), mask: (B, N) + loss_unitvec = torch.norm(pred_unitvec - gt_unitvec, dim=-1)[mask > 0] + return torch.mean(loss_unitvec) + # pt_diff = pred_unitvec - gt_unitvec + # pt_dist = torch.sum(pt_diff.abs(), dim=-1) # (B, N) + # loss_pt_dist = torch.mean(pt_dist[mask > 0]) + # pred_unitvec_normalized = pred_unitvec / (torch.norm(pred_unitvec, dim=-1, keepdim=True) + 1e-7) + # gt_unitvec_normalized = gt_unitvec / (torch.norm(gt_unitvec, dim=-1, keepdim=True) + 1e-7) + # dir_diff = torch.sum(-(pred_unitvec_normalized * gt_unitvec_normalized), dim=-1) # (B, N) + # loss_dir_diff = torch.mean(dir_diff[mask > 0]) + # loss_unitvec = loss_pt_dist + loss_dir_diff + # return loss_unitvec + + +def focal_loss(inputs:torch.Tensor, targets:torch.Tensor, alpha:Optional[torch.Tensor]=None, gamma:float=2.0, ignore_index:Optional[int]=None) -> torch.Tensor: + # inputs: (N, C), targets: (N,) + if ignore_index is not None: + valid_mask = targets != ignore_index + targets = targets[valid_mask] + + if targets.shape[0] == 0: + return torch.tensor(0.0).to(dtype=inputs.dtype, device=inputs.device) + + inputs = inputs[valid_mask] + + log_p = torch.clamp(torch.log(inputs + 1e-7), max=0.0) + if ignore_index is not None: + ce_loss = F.nll_loss( + log_p, targets, weight=alpha, ignore_index=ignore_index, reduction="none" + ) + else: + ce_loss = F.nll_loss( + log_p, targets, weight=alpha, reduction="none" + ) + log_p_t = log_p.gather(1, targets[:, None]).squeeze(-1) + loss = ce_loss * ((1 - log_p_t.exp()) ** gamma) + loss = loss.mean() + return loss + +def dice_loss(input:torch.Tensor, target:torch.Tensor) -> 
torch.Tensor: + def one_hot(labels:torch.Tensor, num_classes:int, device:Optional[torch.device]=None, dtype:Optional[torch.dtype]=None) -> torch.Tensor: + if not isinstance(labels, torch.Tensor): + raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}") + + if not labels.dtype == torch.int64: + raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}") + + if num_classes < 1: + raise ValueError("The number of classes must be bigger than one." " Got: {}".format(num_classes)) + + shape = labels.shape + one_hot = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype) + return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + 1e-6 + # input: (N, C, H, W), target: (N, H, W) + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {input.shape}") + + if not input.shape[-2:] == target.shape[-2:]: + raise ValueError(f"input and target shapes must be the same. Got: {input.shape} and {target.shape}") + + if not input.device == target.device: + raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}") + + # create the labels one hot tensor + target_one_hot = one_hot(target, num_classes=input.shape[1], device=input.device, dtype=input.dtype) + + # compute the actual dice score + dims = (1, 2, 3) + intersection = torch.sum(input * target_one_hot, dims) + cardinality = torch.sum(input + target_one_hot, dims) + + dice_score = 2.0 * intersection / (cardinality + 1e-8) + + return torch.mean(-dice_score + 1.0) + + +def calculate_offset_loss(pred_offsets:torch.Tensor, gt_offsets:torch.Tensor, function_masks:torch.Tensor) -> torch.Tensor: + # pred_offsets: (B, N, 3), gt_offsets: (B, N, 3), function_masks: (B, N) + pt_diff = pred_offsets - gt_offsets + pt_dist = torch.sum(pt_diff.abs(), dim=-1) # (B, N) + loss_pt_offset_dist = pt_dist[function_masks < 2].mean() + loss_pt_offset_dist[torch.isnan(loss_pt_offset_dist)] = 0.0 + gt_offsets_norm = torch.norm(gt_offsets, dim=-1, keepdim=True) + gt_offsets_normalized = gt_offsets / (gt_offsets_norm + 1e-8) + pred_offsets_norm = torch.norm(pred_offsets, dim=-1, keepdim=True) + pred_offsets_normalized = pred_offsets / (pred_offsets_norm + 1e-8) + dir_diff = -(gt_offsets_normalized * pred_offsets_normalized).sum(dim=-1) # (B, N) + loss_offset_dir = dir_diff[function_masks < 2].mean() + loss_offset_dir[torch.isnan(loss_offset_dir)] = 0.0 + loss_offset = loss_offset_dir + loss_pt_offset_dist + return loss_offset + +def calculate_dir_loss(pred_dirs:torch.Tensor, gt_dirs:torch.Tensor, function_masks:torch.Tensor) -> torch.Tensor: + # pred_dirs: (B, N, 3), gt_dirs: (B, N, 3), function_masks: (B, N) + pt_diff = pred_dirs - gt_dirs + pt_dist = torch.sum(pt_diff.abs(), dim=-1) # (B, N) + dis_loss = pt_dist[function_masks < 2].mean() + dis_loss[torch.isnan(dis_loss)] = 0.0 + dir_loss = -(pred_dirs * gt_dirs).sum(dim=-1) # (B, N) + dir_loss = dir_loss[function_masks < 2].mean() + dir_loss[torch.isnan(dir_loss)] = 0.0 + loss = dis_loss + dir_loss + return loss + + +def calc_pose_error(pose1:np.ndarray, pose2:np.ndarray) -> Tuple[float, float]: + error_matrix = np.dot(np.linalg.inv(pose1), pose2) + translation_error = error_matrix[:3, 3] + rotation_error = np.arccos(np.clip((np.trace(error_matrix[:3, :3]) - 1) / 2, -1, 1)) + return (np.linalg.norm(translation_error), rotation_error) + + +def 
invaffordance_metrics(grasp_translation:np.ndarray, grasp_rotation:np.ndarray, grasp_score:float, affordable_position:np.ndarray, + joint_base:np.ndarray, joint_direction:np.ndarray, joint_type:int) -> Tuple[float, float, float]: + if joint_type == 0: + # revolute + l2_dist = np.linalg.norm(grasp_translation - affordable_position) + # normal = np.cross(joint_base - affordable_position, joint_base + joint_direction - affordable_position) + # normal = normal / np.linalg.norm(normal) + # plane_dist = abs(np.dot(grasp_translation - affordable_position, normal)) + # return (1.0 - grasp_score,) + # return (l2_dist, plane_dist, 1.0 - grasp_score) + return (l2_dist,) + elif joint_type == 1: + # prismatic + l2_dist = np.linalg.norm(grasp_translation - affordable_position) + # plane_dist = abs(np.dot(grasp_translation - affordable_position, joint_direction)) + # return (1.0 - grasp_score,) + # return (l2_dist, plane_dist, 1.0 - grasp_score) + return (l2_dist,) + else: + raise ValueError(f"Invalid joint_type: {joint_type}") + +def invaffordances2affordance(invaffordances:List[Tuple[float, float, float]], sort:bool=False) -> List[float]: + if len(invaffordances) == 0: + return [] + if len(invaffordances) == 1: + return [1.0] + # from list of tuple to dict of list, k: metrics index, v: invaffordance list + invaffordances_dict = {k: [v[k] for v in invaffordances] for k in range(len(invaffordances[0]))} + if sort: + # sort each list in dict + affordances_dict = {} + for k in invaffordances_dict.keys(): + affordances_dict[k] = len(invaffordances_dict[k]) - 1 - np.argsort(np.argsort(invaffordances_dict[k])) + affordances_dict[k] /= len(invaffordances) - 1 + else: + # normalize each list in dict + affordances_dict = {} + for k in invaffordances_dict.keys(): + affordances_dict[k] = (np.max(invaffordances_dict[k]) - invaffordances_dict[k]) / (np.max(invaffordances_dict[k]) - np.min(invaffordances_dict[k])) + # merge into list + affordances = np.array([0.0 for _ in range(len(invaffordances))]) + for k in affordances_dict.keys(): + affordances += affordances_dict[k] + affordances /= len(affordances_dict.keys()) + return affordances.tolist() + + +class AverageMeter(object): + """Computes and stores the average and current value + Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 + """ + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def log_metrics(results_dict:Dict[str, Union[np.ndarray, int, List[str]]], logger:Logger, log_path:str, + tb_writer:Optional[SummaryWriter]=None, epoch:Optional[int]=None, split:Optional[str]=None) -> None: + translation_distance_errors = results_dict['translation_distance_errors'] + translation_along_errors = results_dict['translation_along_errors'] + translation_perp_errors = results_dict['translation_perp_errors'] + translation_plane_errors = results_dict['translation_plane_errors'] + translation_line_errors = results_dict['translation_line_errors'] + translation_outliers_num = results_dict['translation_outliers_num'] + rotation_errors = results_dict['rotation_errors'] + rotation_outliers_num = results_dict['rotation_outliers_num'] + if 'affordance_errors' in results_dict.keys(): + affordance_errors = results_dict['affordance_errors'] + affordance_outliers_num = results_dict['affordance_outliers_num'] + else: + affordance_errors = [0, 0] + 
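+        # no affordance metrics in this run: keep zero placeholders so the statistics and logging below still work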
affordance_outliers_num = 0 + + if len(translation_distance_errors.shape) == 1: + data_num = translation_distance_errors.shape[0] + if 'names' in results_dict.keys(): + names = results_dict['names'] + dataframe = pd.DataFrame({'name': names, + 'translation_distance_errors': translation_distance_errors, + 'translation_along_errors': translation_along_errors, + 'translation_perp_errors': translation_perp_errors, + 'translation_plane_errors': translation_plane_errors, + 'translation_line_errors': translation_line_errors, + 'rotation_errors': rotation_errors, + 'affordance_errors': affordance_errors}) + dataframe.to_csv(os.path.join(log_path, 'metrics.csv'), sep=',') + else: + pass + + # mean + mean_translation_distance_error = np.mean(translation_distance_errors, axis=0) + mean_translation_along_error = np.mean(translation_along_errors, axis=0) + mean_translation_perp_error = np.mean(translation_perp_errors, axis=0) + mean_translation_plane_error = np.mean(translation_plane_errors, axis=0) + mean_translation_line_error = np.mean(translation_line_errors, axis=0) + mean_rotation_error = np.mean(rotation_errors, axis=0) + mean_affordance_error = np.mean(affordance_errors, axis=0) + logger.info(f"mean_translation_distance_error = {mean_translation_distance_error}") + logger.info(f"mean_translation_along_error = {mean_translation_along_error}") + logger.info(f"mean_translation_perp_error = {mean_translation_perp_error}") + logger.info(f"mean_translation_plane_error = {mean_translation_plane_error}") + logger.info(f"mean_translation_line_error = {mean_translation_line_error}") + logger.info(f"mean_rotation_error = {mean_rotation_error}") + logger.info(f"mean_affordance_error = {mean_affordance_error}") + # median + median_translation_distance_error = np.median(translation_distance_errors, axis=0) + median_translation_along_error = np.median(translation_along_errors, axis=0) + median_translation_perp_error = np.median(translation_perp_errors, axis=0) + median_translation_plane_error = np.median(translation_plane_errors, axis=0) + median_translation_line_error = np.median(translation_line_errors, axis=0) + median_rotation_error = np.median(rotation_errors, axis=0) + median_affordance_error = np.median(affordance_errors, axis=0) + logger.info(f"median_translation_distance_error = {median_translation_distance_error}") + logger.info(f"median_translation_along_error = {median_translation_along_error}") + logger.info(f"median_translation_perp_error = {median_translation_perp_error}") + logger.info(f"median_translation_plane_error = {median_translation_plane_error}") + logger.info(f"median_translation_line_error = {median_translation_line_error}") + logger.info(f"median_rotation_error = {median_rotation_error}") + logger.info(f"median_affordance_error = {median_affordance_error}") + # max + max_translation_distance_error = np.max(translation_distance_errors, axis=0) + max_translation_along_error = np.max(translation_along_errors, axis=0) + max_translation_perp_error = np.max(translation_perp_errors, axis=0) + max_translation_plane_error = np.max(translation_plane_errors, axis=0) + max_translation_line_error = np.max(translation_line_errors, axis=0) + max_rotation_error = np.max(rotation_errors, axis=0) + max_affordance_error = np.max(affordance_errors, axis=0) + logger.info(f"max_translation_distance_error = {max_translation_distance_error}") + logger.info(f"max_translation_along_error = {max_translation_along_error}") + logger.info(f"max_translation_perp_error = {max_translation_perp_error}") + 
logger.info(f"max_translation_plane_error = {max_translation_plane_error}") + logger.info(f"max_translation_line_error = {max_translation_line_error}") + logger.info(f"max_rotation_error = {max_rotation_error}") + logger.info(f"max_affordance_error = {max_affordance_error}") + # min + min_translation_distance_error = np.min(translation_distance_errors, axis=0) + min_translation_along_error = np.min(translation_along_errors, axis=0) + min_translation_perp_error = np.min(translation_perp_errors, axis=0) + min_translation_plane_error = np.min(translation_plane_errors, axis=0) + min_translation_line_error = np.min(translation_line_errors, axis=0) + min_rotation_error = np.min(rotation_errors, axis=0) + min_affordance_error = np.min(affordance_errors, axis=0) + logger.info(f"min_translation_distance_error = {min_translation_distance_error}") + logger.info(f"min_translation_along_error = {min_translation_along_error}") + logger.info(f"min_translation_perp_error = {min_translation_perp_error}") + logger.info(f"min_translation_plane_error = {min_translation_plane_error}") + logger.info(f"min_translation_line_error = {min_translation_line_error}") + logger.info(f"min_rotation_error = {min_rotation_error}") + logger.info(f"min_affordance_error = {min_affordance_error}") + # std + std_translation_distance_error = np.std(translation_distance_errors, axis=0) + std_translation_along_error = np.std(translation_along_errors, axis=0) + std_translation_perp_error = np.std(translation_perp_errors, axis=0) + std_translation_plane_error = np.std(translation_plane_errors, axis=0) + std_translation_line_error = np.std(translation_line_errors, axis=0) + std_rotation_error = np.std(rotation_errors, axis=0) + std_affordance_error = np.std(affordance_errors, axis=0) + logger.info(f"std_translation_distance_error = {std_translation_distance_error}") + logger.info(f"std_translation_along_error = {std_translation_along_error}") + logger.info(f"std_translation_perp_error = {std_translation_perp_error}") + logger.info(f"std_translation_plane_error = {std_translation_plane_error}") + logger.info(f"std_translation_line_error = {std_translation_line_error}") + logger.info(f"std_rotation_error = {std_rotation_error}") + logger.info(f"std_affordance_error = {std_affordance_error}") + # outliers + translation_outliers_ratio = translation_outliers_num / data_num + rotation_outliers_ratio = rotation_outliers_num / data_num + affordance_outliers_ratio = affordance_outliers_num / data_num + logger.info(f"translation_outliers_num = {translation_outliers_num}") + logger.info(f"rotation_outliers_num = {rotation_outliers_num}") + logger.info(f"affordance_outliers_num = {affordance_outliers_num}") + logger.info(f"translation_outliers_ratio = {translation_outliers_ratio}") + logger.info(f"rotation_outliers_ratio = {rotation_outliers_ratio}") + logger.info(f"affordance_outliers_ratio = {affordance_outliers_ratio}") + + logger.info(f"data_num = {data_num}") + + if tb_writer is not None: + tb_writer.add_scalars(f'{split}/translation_distance_error', { + 'mean': mean_translation_distance_error.item(), + 'median': median_translation_distance_error.item(), + 'max': max_translation_distance_error.item(), + 'min': min_translation_distance_error.item(), + 'std': std_translation_distance_error.item() + }, epoch) + tb_writer.add_scalars(f'{split}/translation_along_error', { + 'mean': mean_translation_along_error.item(), + 'median': median_translation_along_error.item(), + 'max': max_translation_along_error.item(), + 'min': 
min_translation_along_error.item(), + 'std': std_translation_along_error.item() + }, epoch) + tb_writer.add_scalars(f'{split}/translation_perp_error', { + 'mean': mean_translation_perp_error.item(), + 'median': median_translation_perp_error.item(), + 'max': max_translation_perp_error.item(), + 'min': min_translation_perp_error.item(), + 'std': std_translation_perp_error.item() + }, epoch) + tb_writer.add_scalars(f'{split}/translation_plane_error', { + 'mean': mean_translation_plane_error.item(), + 'median': median_translation_plane_error.item(), + 'max': max_translation_plane_error.item(), + 'min': min_translation_plane_error.item(), + 'std': std_translation_plane_error.item() + }, epoch) + tb_writer.add_scalars(f'{split}/translation_line_error', { + 'mean': mean_translation_line_error.item(), + 'median': median_translation_line_error.item(), + 'max': max_translation_line_error.item(), + 'min': min_translation_line_error.item(), + 'std': std_translation_line_error.item() + }, epoch) + tb_writer.add_scalars(f'{split}/rotation_error', { + 'mean': mean_rotation_error.item(), + 'median': median_rotation_error.item(), + 'max': max_rotation_error.item(), + 'min': min_rotation_error.item(), + 'std': std_rotation_error.item() + }, epoch) + tb_writer.add_scalars(f'{split}/affordance_error', { + 'mean': mean_affordance_error.item(), + 'median': median_affordance_error.item(), + 'max': max_affordance_error.item(), + 'min': min_affordance_error.item(), + 'std': std_affordance_error.item() + }, epoch) + tb_writer.add_scalars(f'{split}/outliers_ratio', { + 'translation': translation_outliers_ratio, + 'rotation': rotation_outliers_ratio, + 'affordance': affordance_outliers_ratio + }, epoch) + else: + pass + elif len(translation_distance_errors.shape) == 2: + data_num = translation_distance_errors.shape[0] + joint_num = translation_distance_errors.shape[1] + if 'names' in results_dict.keys(): + names = results_dict['names'] + for j in range(joint_num): + dataframe = pd.DataFrame({'name': names, + 'translation_distance_errors': translation_distance_errors[:, j], + 'translation_along_errors': translation_along_errors[:, j], + 'translation_perp_errors': translation_perp_errors[:, j], + 'translation_plane_errors': translation_plane_errors[:, j], + 'translation_line_errors': translation_line_errors[:, j], + 'rotation_errors': rotation_errors[:, j], + 'affordance_errors': affordance_errors[:, j]}) + dataframe.to_csv(os.path.join(log_path, f'metrics{j}.csv'), sep=',') + else: + pass + + # mean + mean_translation_distance_error = np.mean(translation_distance_errors, axis=0) + mean_translation_along_error = np.mean(translation_along_errors, axis=0) + mean_translation_perp_error = np.mean(translation_perp_errors, axis=0) + mean_translation_plane_error = np.mean(translation_plane_errors, axis=0) + mean_translation_line_error = np.mean(translation_line_errors, axis=0) + mean_rotation_error = np.mean(rotation_errors, axis=0) + mean_affordance_error = np.mean(affordance_errors, axis=0) + logger.info(f"mean_translation_distance_error = {mean_translation_distance_error}") + logger.info(f"mean_translation_along_error = {mean_translation_along_error}") + logger.info(f"mean_translation_perp_error = {mean_translation_perp_error}") + logger.info(f"mean_translation_plane_error = {mean_translation_plane_error}") + logger.info(f"mean_translation_line_error = {mean_translation_line_error}") + logger.info(f"mean_rotation_error = {mean_rotation_error}") + logger.info(f"mean_affordance_error = {mean_affordance_error}") + # median + 
median_translation_distance_error = np.median(translation_distance_errors, axis=0) + median_translation_along_error = np.median(translation_along_errors, axis=0) + median_translation_perp_error = np.median(translation_perp_errors, axis=0) + median_translation_plane_error = np.median(translation_plane_errors, axis=0) + median_translation_line_error = np.median(translation_line_errors, axis=0) + median_rotation_error = np.median(rotation_errors, axis=0) + median_affordance_error = np.median(affordance_errors, axis=0) + logger.info(f"median_translation_distance_error = {median_translation_distance_error}") + logger.info(f"median_translation_along_error = {median_translation_along_error}") + logger.info(f"median_translation_perp_error = {median_translation_perp_error}") + logger.info(f"median_translation_plane_error = {median_translation_plane_error}") + logger.info(f"median_translation_line_error = {median_translation_line_error}") + logger.info(f"median_rotation_error = {median_rotation_error}") + logger.info(f"median_affordance_error = {median_affordance_error}") + # max + max_translation_distance_error = np.max(translation_distance_errors, axis=0) + max_translation_along_error = np.max(translation_along_errors, axis=0) + max_translation_perp_error = np.max(translation_perp_errors, axis=0) + max_translation_plane_error = np.max(translation_plane_errors, axis=0) + max_translation_line_error = np.max(translation_line_errors, axis=0) + max_rotation_error = np.max(rotation_errors, axis=0) + max_affordance_error = np.max(affordance_errors, axis=0) + logger.info(f"max_translation_distance_error = {max_translation_distance_error}") + logger.info(f"max_translation_along_error = {max_translation_along_error}") + logger.info(f"max_translation_perp_error = {max_translation_perp_error}") + logger.info(f"max_translation_plane_error = {max_translation_plane_error}") + logger.info(f"max_translation_line_error = {max_translation_line_error}") + logger.info(f"max_rotation_error = {max_rotation_error}") + logger.info(f"max_affordance_error = {max_affordance_error}") + # min + min_translation_distance_error = np.min(translation_distance_errors, axis=0) + min_translation_along_error = np.min(translation_along_errors, axis=0) + min_translation_perp_error = np.min(translation_perp_errors, axis=0) + min_translation_plane_error = np.min(translation_plane_errors, axis=0) + min_translation_line_error = np.min(translation_line_errors, axis=0) + min_rotation_error = np.min(rotation_errors, axis=0) + min_affordance_error = np.min(affordance_errors, axis=0) + logger.info(f"min_translation_distance_error = {min_translation_distance_error}") + logger.info(f"min_translation_along_error = {min_translation_along_error}") + logger.info(f"min_translation_perp_error = {min_translation_perp_error}") + logger.info(f"min_translation_plane_error = {min_translation_plane_error}") + logger.info(f"min_translation_line_error = {min_translation_line_error}") + logger.info(f"min_rotation_error = {min_rotation_error}") + logger.info(f"min_affordance_error = {min_affordance_error}") + # std + std_translation_distance_error = np.std(translation_distance_errors, axis=0) + std_translation_along_error = np.std(translation_along_errors, axis=0) + std_translation_perp_error = np.std(translation_perp_errors, axis=0) + std_translation_plane_error = np.std(translation_plane_errors, axis=0) + std_translation_line_error = np.std(translation_line_errors, axis=0) + std_rotation_error = np.std(rotation_errors, axis=0) + std_affordance_error = 
np.std(affordance_errors, axis=0) + logger.info(f"std_translation_distance_error = {std_translation_distance_error}") + logger.info(f"std_translation_along_error = {std_translation_along_error}") + logger.info(f"std_translation_perp_error = {std_translation_perp_error}") + logger.info(f"std_translation_plane_error = {std_translation_plane_error}") + logger.info(f"std_translation_line_error = {std_translation_line_error}") + logger.info(f"std_rotation_error = {std_rotation_error}") + logger.info(f"std_affordance_error = {std_affordance_error}") + # outliers + translation_outliers_ratio = translation_outliers_num / (data_num * joint_num) + rotation_outliers_ratio = rotation_outliers_num / (data_num * joint_num) + affordance_outliers_ratio = affordance_outliers_num / (data_num * joint_num) + logger.info(f"translation_outliers_num = {translation_outliers_num}") + logger.info(f"rotation_outliers_num = {rotation_outliers_num}") + logger.info(f"affordance_outliers_num = {affordance_outliers_num}") + logger.info(f"translation_outliers_ratio = {translation_outliers_ratio}") + logger.info(f"rotation_outliers_ratio = {rotation_outliers_ratio}") + logger.info(f"affordance_outliers_ratio = {affordance_outliers_ratio}") + + logger.info(f"data_num = {data_num}, joint_num = {joint_num}") + + if tb_writer is not None: + for j in range(joint_num): + tb_writer.add_scalars(f'{split}/joint_{j}/translation_distance_error', { + 'mean': mean_translation_distance_error[j], + 'median': median_translation_distance_error[j], + 'max': max_translation_distance_error[j], + 'min': min_translation_distance_error[j], + 'std': std_translation_distance_error[j] + }, epoch) + tb_writer.add_scalars(f'{split}/joint_{j}/translation_along_error', { + 'mean': mean_translation_along_error[j], + 'median': median_translation_along_error[j], + 'max': max_translation_along_error[j], + 'min': min_translation_along_error[j], + 'std': std_translation_along_error[j] + }, epoch) + tb_writer.add_scalars(f'{split}/joint_{j}/translation_perp_error', { + 'mean': mean_translation_perp_error[j], + 'median': median_translation_perp_error[j], + 'max': max_translation_perp_error[j], + 'min': min_translation_perp_error[j], + 'std': std_translation_perp_error[j] + }, epoch) + tb_writer.add_scalars(f'{split}/joint_{j}/translation_plane_error', { + 'mean': mean_translation_plane_error[j], + 'median': median_translation_plane_error[j], + 'max': max_translation_plane_error[j], + 'min': min_translation_plane_error[j], + 'std': std_translation_plane_error[j] + }, epoch) + tb_writer.add_scalars(f'{split}/joint_{j}/translation_line_error', { + 'mean': mean_translation_line_error[j], + 'median': median_translation_line_error[j], + 'max': max_translation_line_error[j], + 'min': min_translation_line_error[j], + 'std': std_translation_line_error[j] + }, epoch) + tb_writer.add_scalars(f'{split}/joint_{j}/rotation_error', { + 'mean': mean_rotation_error[j], + 'median': median_rotation_error[j], + 'max': max_rotation_error[j], + 'min': min_rotation_error[j], + 'std': std_rotation_error[j] + }, epoch) + tb_writer.add_scalars(f'{split}/joint_{j}/affordance_error', { + 'mean': mean_affordance_error[j], + 'median': median_affordance_error[j], + 'max': max_affordance_error[j], + 'min': min_affordance_error[j], + 'std': std_affordance_error[j] + }, epoch) + tb_writer.add_scalars(f'{split}/outliers_ratio', { + 'translation': translation_outliers_ratio, + 'rotation': rotation_outliers_ratio, + 'affordance': affordance_outliers_ratio + }, epoch) + else: + pass + else: + 
raise ValueError(f"Invalid shape of translation_distance_errors: {translation_distance_errors.shape}") diff --git a/utilities/model_utils.py b/utilities/model_utils.py new file mode 100644 index 0000000..97dcd3f --- /dev/null +++ b/utilities/model_utils.py @@ -0,0 +1,4 @@ +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True diff --git a/utilities/network_utils.py b/utilities/network_utils.py new file mode 100644 index 0000000..6e7a175 --- /dev/null +++ b/utilities/network_utils.py @@ -0,0 +1,28 @@ +from typing import Optional +from paramiko import SSHClient +from scp import SCPClient + + +def send(local_path:str, remote_path:str, + remote_ip:str, port:int=22, username:Optional[str]=None, password:Optional[str]=None, key_filename:Optional[str]=None) -> None: + ssh = SSHClient() + ssh.load_system_host_keys() + ssh.connect(remote_ip, port=port, username=username, password=password, key_filename=key_filename) + + scp = SCPClient(ssh.get_transport()) + + scp.put(local_path, remote_path) + + scp.close() + +def read(local_path:str, remote_path:str, + remote_ip:str, port:int=22, username:Optional[str]=None, password:Optional[str]=None, key_filename:Optional[str]=None) -> None: + ssh = SSHClient() + ssh.load_system_host_keys() + ssh.connect(remote_ip, port=port, username=username, password=password, key_filename=key_filename) + + scp = SCPClient(ssh.get_transport()) + + scp.get(remote_path, local_path) + + scp.close() diff --git a/utilities/vis_utils.py b/utilities/vis_utils.py new file mode 100644 index 0000000..d51f08d --- /dev/null +++ b/utilities/vis_utils.py @@ -0,0 +1,481 @@ +from typing import Tuple, Union, Optional +import numpy as np +import open3d as o3d +import matplotlib.cm as cm + +from .constants import EPS + + +def visualize(pc:np.ndarray, pc_color:Optional[np.ndarray]=None, pc_normal:Optional[np.ndarray]=None, + joint_translations:Optional[np.ndarray]=None, joint_rotations:Optional[np.ndarray]=None, affordable_positions:Optional[np.ndarray]=None, + joint_axis_colors:Optional[np.ndarray]=None, joint_point_colors:Optional[np.ndarray]=None, affordable_position_colors:Optional[np.ndarray]=None, + grasp_poses:Optional[np.ndarray]=None, grasp_widths:Optional[np.ndarray]=None, grasp_depths:Optional[np.ndarray]=None, grasp_affordances:Optional[np.ndarray]=None, + whether_frame:bool=True, whether_bbox:Union[bool, Tuple[np.ndarray, np.ndarray]]=True, window_name:str='visualization') -> None: + geometries = [] + + if whether_frame: + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + geometries.append(frame) + else: + pass + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc) + if pc_color is None: + pass + elif len(pc_color.shape) == 1: + pcd.paint_uniform_color(pc_color) + elif len(pc_color.shape) == 2: + pcd.colors = o3d.utility.Vector3dVector(pc_color) + else: + raise NotImplementedError + if pc_normal is None: + pass + else: + pcd.normals = o3d.utility.Vector3dVector(pc_normal) + geometries.append(pcd) + + if isinstance(whether_bbox, bool) and whether_bbox: + bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points) + bbox.color = np.array([1, 0, 0]) + geometries.append(bbox) + elif isinstance(whether_bbox, tuple): + bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound=whether_bbox[0], max_bound=whether_bbox[1]) + bbox.color = np.array([1, 0, 0]) + geometries.append(bbox) + else: + pass + + joint_num = joint_translations.shape[0] if joint_translations is 
not None else 0 + for j in range(joint_num): + joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1) + rotation = np.zeros((3, 3)) + temp2 = np.cross(joint_rotations[j], np.array([1., 0., 0.])) + if np.linalg.norm(temp2) < EPS: + temp1 = np.cross(np.array([0., 1., 0.]), joint_rotations[j]) + temp1 /= np.linalg.norm(temp1) + temp2 = np.cross(joint_rotations[j], temp1) + temp2 /= np.linalg.norm(temp2) + else: + temp2 /= np.linalg.norm(temp2) + temp1 = np.cross(temp2, joint_rotations[j]) + temp1 /= np.linalg.norm(temp1) + rotation[:, 0] = temp1 + rotation[:, 1] = temp2 + rotation[:, 2] = joint_rotations[j] + joint_axis.rotate(rotation, np.array([[0], [0], [0]])) + joint_axis.translate(joint_translations[j].reshape((3, 1))) + if joint_axis_colors is None: + pass + elif len(joint_axis_colors.shape) == 1: + joint_axis.paint_uniform_color(joint_axis_colors) + elif len(joint_axis_colors.shape) == 2: + joint_axis.paint_uniform_color(joint_axis_colors[j]) + else: + raise NotImplementedError + geometries.append(joint_axis) + joint_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.015) + joint_point = joint_point.translate(joint_translations[j].reshape((3, 1))) + if joint_point_colors is None: + pass + elif len(joint_point_colors.shape) == 1: + joint_point.paint_uniform_color(joint_point_colors) + elif len(joint_point_colors.shape) == 2: + joint_point.paint_uniform_color(joint_point_colors[j]) + else: + raise NotImplementedError + geometries.append(joint_point) + + if affordable_positions is not None: + affordable_position_num = affordable_positions.shape[0] + for i in range(affordable_position_num): + affordable_position = o3d.geometry.TriangleMesh.create_sphere(radius=0.015) + affordable_position = affordable_position.translate(affordable_positions[i].reshape((3, 1))) + if affordable_position_colors is None: + pass + elif len(affordable_position_colors.shape) == 1: + affordable_position.paint_uniform_color(affordable_position_colors) + elif len(affordable_position_colors.shape) == 2: + affordable_position.paint_uniform_color(affordable_position_colors[i]) + else: + raise NotImplementedError + geometries.append(affordable_position) + + if grasp_poses is not None: + grasp_num = grasp_poses.shape[0] + for g in range(grasp_num): + grasp_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) + grasp_frame.transform(grasp_poses[g]) + geometries.append(grasp_frame) + + finger_width = 0.004 + tail_length = 0.04 + depth_base = 0.02 + gg_width = grasp_widths[g] + gg_depth = grasp_depths[g] + gg_affordance = grasp_affordances[g] + gg_translation = grasp_poses[g][:3, 3] + gg_rotation = grasp_poses[g][:3, :3] + + left = np.zeros((2, 3)) + left[0] = np.array([-depth_base - finger_width, -gg_width / 2, 0]) + left[1] = np.array([gg_depth, -gg_width / 2, 0]) + + right = np.zeros((2, 3)) + right[0] = np.array([-depth_base - finger_width, gg_width / 2, 0]) + right[1] = np.array([gg_depth, gg_width / 2, 0]) + + bottom = np.zeros((2, 3)) + bottom[0] = np.array([-finger_width - depth_base, -gg_width / 2, 0]) + bottom[1] = np.array([-finger_width - depth_base, gg_width / 2, 0]) + + tail = np.zeros((2, 3)) + tail[0] = np.array([-(tail_length + finger_width + depth_base), 0, 0]) + tail[1] = np.array([-(finger_width + depth_base), 0, 0]) + + vertices = np.vstack([left, right, bottom, tail]) + vertices = np.dot(gg_rotation, vertices.T).T + gg_translation + + line_set 
= o3d.geometry.LineSet() + line_set.points = o3d.utility.Vector3dVector(vertices) + line_set.lines = o3d.utility.Vector2iVector([[0, 1], [2, 3], [4, 5], [6, 7]]) + if gg_affordance < 0.5: + line_set.paint_uniform_color([1, 2*gg_affordance, 0]) + elif gg_affordance == 1.0: + line_set.paint_uniform_color([0., 0., 1.]) + else: + line_set.paint_uniform_color([-2*gg_affordance+2, 1, 0]) + geometries.append(line_set) + + o3d.visualization.draw_geometries(geometries, point_show_normal=pc_normal is not None, window_name=window_name) + +def visualize_mask(pc:np.ndarray, instance_mask:np.ndarray, function_mask:np.ndarray, + pc_normal:Optional[np.ndarray]=None, + joint_translations:Optional[np.ndarray]=None, joint_rotations:Optional[np.ndarray]=None, affordable_positions:Optional[np.ndarray]=None, + whether_frame:bool=True, whether_bbox:Union[bool, Tuple[np.ndarray, np.ndarray]]=True, window_name:str='visualization') -> None: + geometries = [] + + if whether_frame: + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + geometries.append(frame) + else: + pass + + instance_ids = np.unique(instance_mask) # [0, J] + functions = [] + for instance_id in instance_ids: + instance_function = function_mask[instance_mask == instance_id] + functions.append(np.unique(instance_function)[0]) + revolute_num, prismatic_num, fixed_num = 0, 0, 0 + for f in functions: + if f == 0: + revolute_num += 1 + elif f == 1: + prismatic_num += 1 + elif f == 2: + fixed_num += 1 + else: + raise ValueError(f"Invalid function {f}") + # assert fixed_num == 1 + revolute_gradient = 1.0 / revolute_num if revolute_num > 0 else 0.0 + prismatic_gradient = 1.0 / prismatic_num if prismatic_num > 0 else 0.0 + + pc_color = np.zeros((pc.shape[0], 3), dtype=np.float32) + revolute_num, prismatic_num = 0, 0 + for instance_idx, instance_id in enumerate(instance_ids): + if functions[instance_idx] == 0: + pc_color[instance_mask == instance_id] = np.array([1. - revolute_gradient * revolute_num, 0., 0.]) + revolute_num += 1 + elif functions[instance_idx] == 1: + pc_color[instance_mask == instance_id] = np.array([0., 1. - prismatic_gradient * prismatic_num, 0.]) + prismatic_num += 1 + elif functions[instance_idx] == 2: + pc_color[instance_mask == instance_id] = np.array([0., 0., 0.]) + else: + raise ValueError(f"Invalid function {functions[instance_idx]}") + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc) + pcd.colors = o3d.utility.Vector3dVector(pc_color) + if pc_normal is None: + pass + else: + pcd.normals = o3d.utility.Vector3dVector(pc_normal) + geometries.append(pcd) + + if isinstance(whether_bbox, bool) and whether_bbox: + bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points) + bbox.color = np.array([1, 0, 0]) + geometries.append(bbox) + elif isinstance(whether_bbox, tuple): + bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound=whether_bbox[0], max_bound=whether_bbox[1]) + bbox.color = np.array([1, 0, 0]) + geometries.append(bbox) + else: + pass + + joint_num = joint_translations.shape[0] if joint_translations is not None else 0 + revolute_num, prismatic_num = 0, 0 + for j in range(joint_num): + joint_function = functions[j+1] + if joint_function == 0: + joint_color = np.array([1. - revolute_gradient * revolute_num, 0., 0.]) + revolute_num += 1 + elif joint_function == 1: + joint_color = np.array([0., 1. 
- prismatic_gradient * prismatic_num, 0.]) + prismatic_num += 1 + else: + raise ValueError(f"Invalid function {joint_function}") + joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1) + rotation = np.zeros((3, 3)) + temp2 = np.cross(joint_rotations[j], np.array([1., 0., 0.])) + if np.linalg.norm(temp2) < EPS: + temp1 = np.cross(np.array([0., 1., 0.]), joint_rotations[j]) + temp1 /= np.linalg.norm(temp1) + temp2 = np.cross(joint_rotations[j], temp1) + temp2 /= np.linalg.norm(temp2) + else: + temp2 /= np.linalg.norm(temp2) + temp1 = np.cross(temp2, joint_rotations[j]) + temp1 /= np.linalg.norm(temp1) + rotation[:, 0] = temp1 + rotation[:, 1] = temp2 + rotation[:, 2] = joint_rotations[j] + joint_axis.rotate(rotation, np.array([[0], [0], [0]])) + joint_axis.translate(joint_translations[j].reshape((3, 1))) + joint_axis.paint_uniform_color(joint_color) + geometries.append(joint_axis) + joint_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.015) + joint_point = joint_point.translate(joint_translations[j].reshape((3, 1))) + joint_point.paint_uniform_color(joint_color) + geometries.append(joint_point) + if affordable_positions is not None: + affordance_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.015) + affordance_point = affordance_point.translate(affordable_positions[j].reshape((3, 1))) + affordance_point.paint_uniform_color(joint_color) + geometries.append(affordance_point) + + o3d.visualization.draw_geometries(geometries, point_show_normal=pc_normal is not None, window_name=window_name) + + +def visualize_translation_voting(grid_pc:np.ndarray, votes_list:np.ndarray, + pc:Optional[np.ndarray]=None, pc_color:Optional[np.ndarray]=None, + gt_translation:Optional[np.ndarray]=None, gt_color:Optional[np.ndarray]=None, + pred_translation:Optional[np.ndarray]=None, pred_color:Optional[np.ndarray]=None, + show_threshold:float=0.5, whether_frame:bool=True, whether_bbox:bool=True, + window_name:str='visualization') -> None: + geometries = [] + + if whether_frame: + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + geometries.append(frame) + else: + pass + + grid_pcd = o3d.geometry.PointCloud() + votes_list = votes_list / np.max(votes_list) + grid_pc_color = np.zeros((grid_pc.shape[0], 3)) + grid_pc_color = np.stack([np.ones_like(votes_list), 1-votes_list, 1-votes_list], axis=-1) + grid_pcd.points = o3d.utility.Vector3dVector(grid_pc[votes_list >= show_threshold]) + grid_pcd.colors = o3d.utility.Vector3dVector(grid_pc_color[votes_list >= show_threshold]) + geometries.append(grid_pcd) + print(grid_pc[votes_list >= show_threshold].shape[0]) + + if whether_bbox: + grid_bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(o3d.utility.Vector3dVector(grid_pc)) + grid_bbox.color = np.array([0, 1, 0]) + geometries.append(grid_bbox) + else: + pass + + if gt_translation is None: + pass + else: + gt_point = o3d.geometry.PointCloud() + gt_point.points = o3d.utility.Vector3dVector(gt_translation[None, :]) + if gt_color is None: + pass + else: + gt_point.paint_uniform_color(gt_color) + geometries.append(gt_point) + + if pred_translation is None: + pass + else: + pred_point = o3d.geometry.PointCloud() + pred_point.points = o3d.utility.Vector3dVector(pred_translation[None, :]) + if pred_color is None: + pass + else: + pred_point.paint_uniform_color(pred_color) + geometries.append(pred_point) + + if pc is None: + pass + else: + pcd = 
o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc) + if pc_color is None: + pass + elif len(pc_color.shape) == 1: + pcd.paint_uniform_color(pc_color) + elif len(pc_color.shape) == 2: + pcd.colors = o3d.utility.Vector3dVector(pc_color) + else: + raise NotImplementedError + geometries.append(pcd) + + if whether_bbox: + bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points) + bbox.color = np.array([1, 0, 0]) + geometries.append(bbox) + else: + pass + + o3d.visualization.draw_geometries(geometries, window_name=window_name) + +def visualize_rotation_voting(sphere_pts:np.ndarray, votes:np.ndarray, + pc:Optional[np.ndarray]=None, pc_color:Optional[np.ndarray]=None, + gt_rotation:Optional[np.ndarray]=None, gt_color:Optional[np.ndarray]=None, + pred_rotation:Optional[np.ndarray]=None, pred_color:Optional[np.ndarray]=None, + show_threshold:float=0.5, whether_frame:bool=True, whether_bbox:bool=True, + window_name:str='visualization') -> None: + geometries = [] + + if whether_frame: + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + geometries.append(frame) + else: + pass + + joint_num = sphere_pts.shape[0] + votes = votes / np.max(votes) + print(votes[votes >= show_threshold].shape[0]) + for j in range(joint_num): + if votes[j] < show_threshold: + continue + joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1) + rotation = np.zeros((3, 3)) + temp2 = np.cross(sphere_pts[j], np.array([1., 0., 0.])) + if np.linalg.norm(temp2) < EPS: + temp1 = np.cross(np.array([0., 1., 0.]), sphere_pts[j]) + temp1 /= np.linalg.norm(temp1) + temp2 = np.cross(sphere_pts[j], temp1) + temp2 /= np.linalg.norm(temp2) + else: + temp2 /= np.linalg.norm(temp2) + temp1 = np.cross(temp2, sphere_pts[j]) + temp1 /= np.linalg.norm(temp1) + rotation[:, 0] = temp1 + rotation[:, 1] = temp2 + rotation[:, 2] = sphere_pts[j] + joint_axis.rotate(rotation, np.array([[0], [0], [0]])) + joint_axis.paint_uniform_color(np.array([1, 1-votes[j], 1-votes[j]])) + geometries.append(joint_axis) + + if gt_rotation is None: + pass + else: + joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1) + rotation = np.zeros((3, 3)) + temp2 = np.cross(gt_rotation, np.array([1., 0., 0.])) + if np.linalg.norm(temp2) < EPS: + temp1 = np.cross(np.array([0., 1., 0.]), gt_rotation) + temp1 /= np.linalg.norm(temp1) + temp2 = np.cross(gt_rotation, temp1) + temp2 /= np.linalg.norm(temp2) + else: + temp2 /= np.linalg.norm(temp2) + temp1 = np.cross(temp2, gt_rotation) + temp1 /= np.linalg.norm(temp1) + rotation[:, 0] = temp1 + rotation[:, 1] = temp2 + rotation[:, 2] = gt_rotation + joint_axis.rotate(rotation, np.array([[0], [0], [0]])) + if gt_color is None: + pass + else: + joint_axis.paint_uniform_color(gt_color) + geometries.append(joint_axis) + + if pred_rotation is None: + pass + else: + joint_axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.01, cone_radius=0.02, cylinder_height=0.4, cone_height=0.1, resolution=20, cylinder_split=4, cone_split=1) + rotation = np.zeros((3, 3)) + temp2 = np.cross(pred_rotation, np.array([1., 0., 0.])) + if np.linalg.norm(temp2) < EPS: + temp1 = np.cross(np.array([0., 1., 0.]), pred_rotation) + temp1 /= np.linalg.norm(temp1) + temp2 = np.cross(pred_rotation, temp1) + temp2 /= np.linalg.norm(temp2) + 
else: + temp2 /= np.linalg.norm(temp2) + temp1 = np.cross(temp2, pred_rotation) + temp1 /= np.linalg.norm(temp1) + rotation[:, 0] = temp1 + rotation[:, 1] = temp2 + rotation[:, 2] = pred_rotation + joint_axis.rotate(rotation, np.array([[0], [0], [0]])) + if pred_color is None: + pass + else: + joint_axis.paint_uniform_color(pred_color) + geometries.append(joint_axis) + + if pc is None: + pass + else: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc) + if pc_color is None: + pass + elif len(pc_color.shape) == 1: + pcd.paint_uniform_color(pc_color) + elif len(pc_color.shape) == 2: + pcd.colors = o3d.utility.Vector3dVector(pc_color) + else: + raise NotImplementedError + geometries.append(pcd) + + if whether_bbox: + bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points) + bbox.color = np.array([1, 0, 0]) + geometries.append(bbox) + else: + pass + + o3d.visualization.draw_geometries(geometries, window_name=window_name) + + +def visualize_confidence_voting(confs:np.ndarray, pc:np.ndarray, point_idxs:np.ndarray, + whether_frame:bool=True, whether_bbox:bool=True, window_name:str='visualization') -> None: + # confs: (N_t,), pc: (N, 3), point_idxs: (N_t, 2) + point_heats = np.zeros((pc.shape[0],)) # (N,) + for i in range(confs.shape[0]): + point_heats[point_idxs[i]] += confs[i] + print(np.max(point_heats), np.min(point_heats[point_heats > 0]), point_heats[point_heats > 0].shape[0]) + point_heats /= np.max(point_heats) + + geometries = [] + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc) + cmap = cm.get_cmap('jet') + pcd.colors = o3d.utility.Vector3dVector(cmap(point_heats)[:, :3]) + geometries.append(pcd) + + if whether_frame: + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + geometries.append(frame) + else: + pass + + if whether_bbox: + bbox = o3d.geometry.AxisAlignedBoundingBox.create_from_points(pcd.points) + bbox.color = np.array([1, 0, 0]) + geometries.append(bbox) + else: + pass + + o3d.visualization.draw_geometries(geometries, window_name=window_name) diff --git a/weights/.gitkeep b/weights/.gitkeep new file mode 100644 index 0000000..e69de29
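
For orientation, here is a minimal usage sketch (not part of the diff above) for the new `visualize` helper in `utilities/vis_utils.py`. It assumes the repository root is on `PYTHONPATH` so that `utilities` is importable as a package, Open3D is installed, a display is available, and `utilities/constants.py` (providing the `EPS` imported by `vis_utils`) exists; all inputs below are synthetic placeholders.

```python
import numpy as np

from utilities.vis_utils import visualize  # assumes `utilities` is an importable package

# synthetic stand-in data: a random point cloud with a single revolute joint
pc = np.random.rand(2048, 3).astype(np.float32)        # (N, 3) points
joint_translations = np.array([[0.5, 0.5, 0.5]])        # (J, 3) pivot point per joint
joint_rotations = np.array([[0.0, 0.0, 1.0]])           # (J, 3) unit joint axis per joint
affordable_positions = np.array([[0.5, 0.9, 0.5]])      # (J, 3) affordable (interaction) point per joint

visualize(pc,
          joint_translations=joint_translations,
          joint_rotations=joint_rotations,
          affordable_positions=affordable_positions,
          joint_axis_colors=np.array([1.0, 0.0, 0.0]),   # a single RGB triple paints all joint axes uniformly
          joint_point_colors=np.array([0.0, 0.0, 1.0]),
          affordable_position_colors=np.array([0.0, 1.0, 0.0]),
          whether_frame=True, whether_bbox=True,
          window_name='rpmart visualization example')
```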