Skip to content

Commit

Permalink
Sample 2D rotations when given axis of rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
LemonPi committed Mar 22, 2024
1 parent dbb641a commit 11de9ad
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/pytorch_kinematics/transforms/perturbation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
from pytorch_kinematics.transforms.rotation_conversions import axis_angle_to_matrix
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_33


def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma):
def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma, axis_of_rotation=None,
translation_perpendicular_to_axis_of_rotation=True):
"""
Sample perturbations around the given transform. The translation and rotation are sampled independently from
0 mean gaussians. The angular perturbations' directions are uniformly sampled from the unit sphere while its
Expand All @@ -11,18 +12,33 @@ def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma):
:param num_perturbations: number of perturbations to sample
:param radian_sigma: standard deviation of the gaussian angular perturbation in radians
:param translation_sigma: standard deviation of the gaussian translation perturbation in meters / T units
:param axis_of_rotation: if not None, the axis of rotation to sample the perturbations around
:param translation_perpendicular_to_axis_of_rotation: if True and the axis_of_rotation is not None, the translation
perturbations will be perpendicular to the axis of rotation
:return: perturbed transforms; may not include the original transform
"""
dtype = T.dtype
device = T.device
perturbed = torch.eye(4, dtype=dtype, device=device).repeat(num_perturbations, 1, 1)

delta_R = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * radian_sigma
delta_R = axis_angle_to_matrix(delta_R)
delta_t = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * translation_sigma
# consider sampling from the Bingham distribution
theta = torch.randn(num_perturbations, dtype=dtype, device=device) * radian_sigma
if axis_of_rotation is not None:
axis_angle = axis_of_rotation
# sample translation perturbation perpendicular to the axis of rotation
# remove the component of delta_t along the axis_of_rotation
if translation_perpendicular_to_axis_of_rotation:
delta_t -= (delta_t * axis_of_rotation).sum(dim=1, keepdim=True) * axis_of_rotation
else:
axis_angle = torch.randn((num_perturbations, 3), dtype=dtype, device=device)
# normalize to unit length
axis_angle = axis_angle / axis_angle.norm(dim=1, keepdim=True)

delta_R = axis_and_angle_to_matrix_33(axis_angle, theta)
perturbed[:, :3, :3] = delta_R @ T[..., :3, :3]
perturbed[:, :3, 3] = T[..., :3, 3]

delta_t = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * translation_sigma
perturbed[:, :3, 3] += delta_t

return perturbed

0 comments on commit 11de9ad

Please sign in to comment.