Skip to content

Commit

Permalink
MPC working w/o exec but model too inaccurate
Browse files Browse the repository at this point in the history
  • Loading branch information
hbuurmei committed Jan 30, 2025
1 parent 16e94ef commit a6b6daa
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 80 deletions.
46 changes: 0 additions & 46 deletions stack/main/scripts/data_visualization.ipynb

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions stack/main/src/controller/controller/mpc_solver_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,18 @@ def gusto_callback(self, request, response):
t, xopt, uopt, zopt
"""
t0 = request.t0
x0 = arr2jnp(request.x0, self.model.n_x, squeeze=True)
y0 = arr2jnp(request.y0, self.model.n_y, squeeze=True)
x0 = self.model.encode(y0)

# Get target values at proper times by interpolating
z, zf, u = self.get_target(t0)

# Get initial guess
idx0 = jnp.argwhere(self.topt >= t0)[0, 0]
u_init = self.uopt[-1, :].reshape(1, -1).repeat(self.N, axis=0)
u_init[0:self.N - idx0] = self.uopt[idx0:, :]
x_init = self.xopt[-1, :].reshape(1, -1).repeat(self.N + 1, axis=0)
x_init[0:self.N + 1 - idx0] = self.xopt[idx0:, :]
u_init = jnp.tile(self.uopt[-1, :].reshape(1, -1), (self.N, 1))
u_init = u_init.at[:self.N - idx0].set(self.uopt[idx0:, :])
x_init = jnp.tile(self.xopt[-1, :].reshape(1, -1), (self.N + 1, 1))
x_init = x_init.at[:self.N + 1 - idx0].set(self.xopt[idx0:, :])

# Solve GuSTO and get solution
self.gusto.solve(x0, u_init, x_init, z=z, zf=zf, u=u)
Expand Down
45 changes: 24 additions & 21 deletions stack/main/src/executor/executor/experiment_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
logging.getLogger('jax').setLevel(logging.ERROR)
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from rclpy.qos import QoSProfile # type: ignore
Expand Down Expand Up @@ -76,10 +77,10 @@ def __init__(self):
# Maintain current observations because of the delay embedding
self.y = None

self.get_logger().info('Run experiment node has been started.')

# Keep a clock for timing
self.clock = self.get_clock()
# self.start_time = self.clock.now().nanoseconds / 1e9

self.get_logger().info('Run experiment node has been started.')

def mocap_listener_callback(self, msg):
if self.debug:
Expand All @@ -100,20 +101,22 @@ def mocap_listener_callback(self, msg):
else:
# At initialization use current obs. as delay embedding
self.y = jnp.tile(y_centered_midtip, 2)
# And record starting time
self.start_time = self.clock.now().nanoseconds / 1e9

t0 = self.clock.now().nanoseconds / 1e9
x0 = self.model.encode(self.y)
t0 = self.clock.now().nanoseconds / 1e9 - self.start_time

# Call the service
self.mpc_client.send_request(t0, x0, wait=False)
self.mpc_client.future.add_done_callback(self.service_callback)
self.send_request(t0, self.y, wait=False)
self.future.add_done_callback(self.service_callback)

def service_callback(self, async_response):
try:
response = async_response.result()
# TODO: enable control execution (for now just print what would be commanded)
# self.publish_control_inputs(response.uopt)
self.get_logger().info(f'We would command the control inputs: {response}.')
# TODO: check if this uopt needs formatting or not (eg use get_solution func.)
self.get_logger().info(f'We would command the control inputs: {response.uopt}.')
except Exception as e:
self.get_logger().error(f'Service call failed: {e}.')

Expand All @@ -126,29 +129,29 @@ def publish_control_inputs(self, control_inputs):
if self.debug:
self.get_logger().info(f'Published new motor control setting: {control_inputs}.')

def send_request(self, t0, x0, wait=False):
def send_request(self, t0, y0, wait=False):
"""
Send request to MPC solver.
"""
self.req.t0 = t0
self.req.x0 = jnp2arr(x0)
self.req.y0 = jnp2arr(y0)
self.future = self.mpc_client.call_async(self.req)

if wait:
# Synchronous call, not compatible for real-time applications
rclpy.spin_until_future_complete(self, self.future)

def get_solution(self, n_x, n_u):
"""
Obtain result from MPC solver.
"""
res = self.future.result()
t = arr2jnp(res.t, 1, squeeze=True)
xopt = arr2jnp(res.xopt, n_x)
uopt = arr2jnp(res.uopt, n_u)
t_solve = res.solve_time

return t, uopt, xopt, t_solve
# def get_solution(self, n_x, n_u):
# """
# Obtain result from MPC solver.
# """
# res = self.future.result()
# t = arr2jnp(res.t, 1, squeeze=True)
# xopt = arr2jnp(res.xopt, n_x)
# uopt = arr2jnp(res.uopt, n_u)
# t_solve = res.solve_time

# return t, uopt, xopt, t_solve

def force_spin(self):
if not self.check_if_done():
Expand Down
1 change: 1 addition & 0 deletions stack/main/src/executor/executor/mpc_initializer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
logging.getLogger('jax').setLevel(logging.ERROR)
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
import rclpy # type: ignore
from rclpy.node import Node # type: ignore
from controller.mpc.gusto import GuSTOConfig # type: ignore
Expand Down
13 changes: 5 additions & 8 deletions stack/main/src/interfaces/srv/ControlSolver.srv
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# Dimensions
int8 horizon
int8 n_u
int8 n_x
int8 n_z

# Time (seconds) associated with x0
float64 t0

# Initial condition array
# Initial state array
float64[] x0

# Initial observation array
float64[] y0

# Initial guess for optimal u for GuSTO
float64[] u_init

Expand Down Expand Up @@ -38,4 +35,4 @@ float64[] uopt
float64[] zopt

# solve time
float64 solve_time
float64 solve_time

0 comments on commit a6b6daa

Please sign in to comment.