Skip to content

Commit

Permalink
Rotation unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
hannorein committed Jan 2, 2025
1 parent b0892b2 commit 98bccf2
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 10 deletions.
56 changes: 47 additions & 9 deletions rebound/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class Rotation(ctypes.Structure):
This class facilitates rotations of Vec3d objects, and provides various convenience functions
for commonly used rotations in celestial mechanics.
"""
def __init__(self, ix=None, iy=None, iz=None, r=None, angle=None, axis=None):
def __init__(self, ix=None, iy=None, iz=None, r=None, angle=None, axis=None, fromv=None, tov=None):
"""
Rotations are implemented as quaternions r + (ix)i + (iy)j + (iz)k. To initialize one
can directly pass a set of the real numbers (ix, iy, iz, r). Alternatively one can pass
Expand All @@ -29,22 +29,34 @@ def __init__(self, ix=None, iy=None, iz=None, r=None, angle=None, axis=None):
shorthand for [1,0,0], [0,1,0], [0,0,1].
"""
cart = [ix, iy, iz, r]
angle_axis = [angle, axis]
if cart.count(None) == len(cart) and angle_axis.count(None) == len(angle_axis):
super(Rotation, self).__init__(0.0,0.0,0.0,1.0) # Identity
angleaxis = [angle, axis]
fromto = [fromv, tov]
supplied = [a.count(None)!=len(a) for a in [cart, angleaxis, fromto]]
if sum(supplied) > 1:
raise ValueError("Cannot mix parameters.")
if cart.count(None) == len(cart) and angleaxis.count(None) == len(angleaxis) and fromto.count(None) == len(fromto):
clibrebound.reb_rotation_identity.restype = Rotation
q = clibrebound.reb_rotation_identity()
super(Rotation, self).__init__(q.ix, q.iy, q.iz, q.r)
return
if cart.count(None) != 0 and angle_axis.count(None) == len(angle_axis):
if cart.count(None) != 0 and cart.count(None) != len(cart):
raise ValueError("You need to specify all four parameters ix, iy, iz, r.")
if angle_axis.count(None) != 0 and cart.count(None) == len(cart):
if angleaxis.count(None) != 0 and angleaxis.count(None) != len(angleaxis):
raise ValueError("You need to specify both angle and axis.")
if cart.count(None) < len(cart) and angle_axis.count(None) < len(angle_axis):
raise ValueError("Cannot mix parameters ix, iy, iz, r with angle, axis.")
if fromto.count(None) != 0 and fromto.count(None) != len(fromto):
raise ValueError("You need to specify both fromv and tov.")
if cart.count(None) == 0:
super(Rotation, self).__init__(ix, iy, iz, r)
if angle_axis.count(None) == 0:
return
if angleaxis.count(None) == 0:
clibrebound.reb_rotation_init_angle_axis.restype = Rotation
q = clibrebound.reb_rotation_init_angle_axis(ctypes.c_double(angle), Vec3d(axis)._vec3d)
super(Rotation, self).__init__(q.ix, q.iy, q.iz, q.r)
return
if fromto.count(None) == 0:
q = Rotation.from_to(fromv, tov)
super(Rotation, self).__init__(q.ix, q.iy, q.iz, q.r)
return

@classmethod
def from_to(cls, fromv, tov):
Expand Down Expand Up @@ -144,6 +156,19 @@ def to_new_axes(cls, newz, newx=None):
clibrebound.reb_rotation_init_to_new_axes.restype = cls
q = clibrebound.reb_rotation_init_to_new_axes(Vec3d(newz)._vec3d, Vec3d(newx)._vec3d)
return q

def orbital(self):
"""
Returns a three vector with orbital elements Omega, inc, omega.
Note: the angles might not always be in the correct quadrant and might be
inconsistent with REBOUND's standard definition of orbital elements.
"""
Omega = ctypes.c_double()
inc = ctypes.c_double()
omega = ctypes.c_double()
clibrebound.reb_rotation_to_orbital(self, ctypes.byref(Omega), ctypes.byref(inc), ctypes.byref(omega))
return [Omega.value, inc.value, omega.value]


def inverse(self):
"""
Expand All @@ -153,6 +178,19 @@ def inverse(self):
q = clibrebound.reb_rotation_inverse(self)
return q

def normalize(self):
"""
Returns a normalized copy of the Rotation object.
"""
clibrebound.reb_rotation_normalize.restype = Rotation
q = clibrebound.reb_rotation_normalize(self)
return q

def __eq__(self, other):
if not isinstance(other, Rotation):
return NotImplemented
return self.ix == other.ix and self.iy == other.iy and self.iz == other.iz and self.r == other.r

def __mul__(self, other):
if isinstance(other, Rotation):
clibrebound.reb_rotation_mul.restype = Rotation
Expand Down
63 changes: 62 additions & 1 deletion rebound/tests/test_rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_from_to_edge(self):
self.assertAlmostEqual(res[1], 0, delta=1e-15)
self.assertAlmostEqual(res[2], 0, delta=1e-15)

def test_to_orbit(self):
def test_orbit(self):
sim = rebound.Simulation()
a, e, inc, Omega, omega = 1, 0.1, 0.2, 0.3, 0.4
sim.add(m=1)
Expand All @@ -45,6 +45,12 @@ def test_to_orbit(self):
self.assertAlmostEqual(res[0], sim.particles[1].x, delta=1e-15)
self.assertAlmostEqual(res[1], sim.particles[1].y, delta=1e-15)
self.assertAlmostEqual(res[2], sim.particles[1].z, delta=1e-15)

# checking inverse: Rotation -> orbital elements
_Omega, _inc, _omega = r.orbital()
self.assertAlmostEqual(_Omega, Omega, delta=2e-15)
self.assertAlmostEqual(_inc, inc, delta=2e-15)
self.assertAlmostEqual(_omega, omega, delta=2e-15)

def test_to_new_axes(self):
sim = rebound.Simulation()
Expand Down Expand Up @@ -98,6 +104,61 @@ def test_to_from_spherical(self):
self.assertAlmostEqual(mag, mag2, delta=1e-15)
self.assertAlmostEqual(theta, theta2, delta=1e-15)
self.assertAlmostEqual(phi, phi2, delta=1e-15)

def test_rotate_sim(self):
sim = rebound.Simulation()
sim.add(m=1)
sim.add(x=1)
r = rebound.Rotation(angle=math.pi/2, axis=[0,0,1])
sim = r * sim
self.assertAlmostEqual(0, sim.particles[0].x, delta=1e-15)
self.assertAlmostEqual(0, sim.particles[0].y, delta=1e-15)
self.assertAlmostEqual(0, sim.particles[0].z, delta=1e-15)
self.assertAlmostEqual(0, sim.particles[1].x, delta=1e-15)
self.assertAlmostEqual(1, sim.particles[1].y, delta=1e-15)
self.assertAlmostEqual(0, sim.particles[1].z, delta=1e-15)

def test_rotate_particle(self):
p = rebound.Particle(x=1)
r = rebound.Rotation(angle=math.pi/2, axis=[0,0,1])
p = r * p
self.assertAlmostEqual(0, p.x, delta=1e-15)
self.assertAlmostEqual(1, p.y, delta=1e-15)
self.assertAlmostEqual(0, p.z, delta=1e-15)

def test_normalize(self):
r1 = rebound.Rotation(ix=1, iy=0, iz=0, r=0)
r2 = rebound.Rotation(ix=2, iy=0, iz=0, r=0)
r3 = r2.normalize()
self.assertNotEqual(r1, r2)
self.assertNotEqual(r2, r3)
self.assertEqual(r1, r3)

def test_identity(self):
r = rebound.Rotation()
a = [1,2,3]
b = r*a
self.assertEqual(a[0], b[0])
self.assertEqual(a[1], b[1])
self.assertEqual(a[2], b[2])

def test_tofrom(self):
sim = rebound.Simulation()
a, e, inc, Omega, omega = 1, 0.1, 0.2, 0.3, 0.4
sim.add(m=1)
sim.add(a=a,e=e)
sim.add(a=a,e=e,inc=inc,Omega=Omega,omega=omega,f=0)
r = rebound.Rotation(fromv=sim.particles[1].xyz, tov=sim.particles[2].xyz)
res = r*sim.particles[1].xyz
self.assertAlmostEqual(res[0], sim.particles[2].x, delta=1e-18)
self.assertAlmostEqual(res[1], sim.particles[2].y, delta=1e-15)
self.assertAlmostEqual(res[2], sim.particles[2].z, delta=1e-15)

r = rebound.Rotation.from_to(fromv=sim.particles[1].xyz, tov=sim.particles[2].xyz)
res = r*sim.particles[1].xyz
self.assertAlmostEqual(res[0], sim.particles[2].x, delta=1e-18)
self.assertAlmostEqual(res[1], sim.particles[2].y, delta=1e-15)
self.assertAlmostEqual(res[2], sim.particles[2].z, delta=1e-15)

if __name__ == "__main__":
unittest.main()

0 comments on commit 98bccf2

Please sign in to comment.