diff --git a/README.md b/README.md
index bb2795b..d9d41ca 100644
--- a/README.md
+++ b/README.md
@@ -21,10 +21,22 @@ That means you can test out the controller in a plug-and-play manner with minimu
## News
+- 11/03/2024: 🎉 Sim2Real pipeline is ready! Check out the [Sim2Real](#deploy-in-real-unitree-go2) section for more details.
- 09/25/2024: 🎉 DIAL-MPC is released with open-source codes! Sim2Real pipeline coming soon!
https://github.com/user-attachments/assets/f2e5f26d-69ac-4478-872e-26943821a218
+
+## Table of Contents
+
+1. [Install](#install-dial-mpc)
+2. [Synchronous Simulation](#synchronous-simulation)
+3. [Asynchronous Simulation](#asynchronous-simulation)
+4. [Deploy in Real](#deploy-in-real-unitree-go2)
+5. [Writing Your Own Environment](#writing-custom-environment)
+6. [Rendering Rollouts](#rendering-rollouts-in-blender)
+7. [Citing this Work](#bibtex)
+
## Simulation Setup
### Install `dial-mpc`
@@ -43,6 +55,8 @@ pip3 install -e .
## Synchronous Simulation
+In this mode, the simulation will wait for DIAL-MPC to finish computing before stepping. It is ideal for debugging and doing tasks that are currently not real-time.
+
#### Run Examples
List available examples:
@@ -61,7 +75,9 @@ After rollout completes, go to `127.0.0.1:5000` to visualize the rollouts.
## Asynchronous Simulation
-The asynchronous simulation is meant to test the algorithm before Sim2Real.
+The asynchronous simulation is meant to test the algorithm before Sim2Real. The simulation rolls out in real-time (or scaled by `real_time_factor`). DIAL-MPC will encounter delay in this case.
+
+When DIAL-MPC cannot finish the compute in the time defined by `dt`, it will spit out warning. Slight overtime is accepetable as DIAL-MPC maintains a buffer of the previous step's solution and will play out the planned action sequence until the buffer runs out.
List available examples:
@@ -85,19 +101,120 @@ dial-mpc-plan --example unitree_go2_seq_jump_deploy
```
-## Deploy in Real
+## Deploy in Real (Unitree Go2)
+
+### Overview
+
+The real-world deployment procedure is very similar to asynchronous simulation.
+
+We use `unitree_sdk2_python` to communicate with the robot directly via CycloneDDS.
+
+### Step 1: State Estimation
+
+For state estimation, this proof-of-concept work requires external localization module to get base **position** and **velocity**.
+
+The following plugins are built-in:
+
+- ROS2 odometry message
+- Vicon motion capture system
+
+#### Option 1: ROS2 odometry message
+
+Configure `odom_topic` in the YAML file. You are responsible for publishing this message at at least 50 Hz and ideally over 100 Hz. We provide an odometry publisher for Vicon motion capture system in [`vicon_interface`](https://github.com/LeCAR-Lab/vicon_interface).
+
+> [!CAUTION]
+> All velocities in ROS2 odometry message **must** be in **body frame** of the base to conform to [ROS odometry message definition](https://docs.ros.org/en/noetic/api/nav_msgs/html/msg/Odometry.html), although in the end they are converted to world frame in DIAL-MPC.
+
+#### Option 2: Vicon (no ROS2 required)
+
+1. `pip install pyvicon-datastream`
+2. Change `localization_plugin` to `vicon_shm_plugin` in the YAML file.
+3. Configure `vicon_tracker_ip`, `vicon_object_name`, and `vicon_z_offset` in the YAML file.
+
+#### Option 3: Bring Your Own Plugin
+
+We provide a simple ABI for custom localization modules, and you need to implement this in a python file in your workspace, should you consider not using the built-in plugins.
+
+```python
+import numpy as np
+import time
+from dial_mpc.deploy.localization import register_plugin
+from dial_mpc.deploy.localization.base_plugin import BaseLocalizationPlugin
+
+class MyPlugin(BaseLocalizationPlugin):
+ def __init__(self, config):
+ pass
+
+ def get_state(self):
+ qpos = np.zeros(7)
+ qvel = np.zeros(6)
+ return np.concatenate([qpos, qvel])
+
+ def get_last_update_time(self):
+ return time.time()
+
+register_plugin('custom_plugin', plugin_cls=MyPlugin)
+```
+
+> [!CAUTION]
+> When writing custom localization plugin, velocities should be reported in **world frame**.
+
+> [!NOTE]
+> Angular velocity source is onboard IMU. You could leave `qvel[3:6]` in the returned state as zero for now.
+
+Localization plugin can be changed in the configuration file. A `--plugin` argument can be supplied to `dial-mpc-real` to import a custom localization plugin in the current workspace.
+
+### Step 2: Installing `unitree_sdk2_python`
+
+> [!NOTE]
+> If you are already using ROS2 with Cyclone DDS according to [ROS2 documentation on Cyclone DDS](https://docs.ros.org/en/humble/Installation/DDS-Implementations/Working-with-Eclipse-CycloneDDS.html), you don't have to install Cyclone DDS as suggested by `unitree_sdk2_python`. But do follow the rest of the instructions.
+
+Follow the instructions in [`unitree_sdk2_python`](https://github.com/unitreerobotics/unitree_sdk2_python).
+
+### Step 3: Configuring DIAL-MPC
+
+In `dial_mpc/examples/unitree_go2_trot_deploy.yaml` or `dial_mpc/examples/unitree_go2_seq_jump.yaml`, modify `network_interface` to match the name of the network interface connected to Go2.
+
+Alternatively, you can also pass `--network_interface` to `dial-mpc-real` when launching the robot, which will override the config.
+
+### Step 4: Starting the Robot
+
+Follow the [official Unitree documentation](https://support.unitree.com/home/en/developer/Quick_start) to disable sports mode on Go2. Lay the robot flat on the ground like shown.
+
+
+
+
+
+### Step 5: Running the Robot
+
+List available examples:
+
+```bash
+dial-mpc-real --list-examples
+```
+
+Run an example:
+
+In terminal 1, run
+
+```bash
+# source /opt/ros//setup.bash # if using ROS2
+dial-mpc-real --example unitree_go2_seq_jump_deploy
+```
+
+This will open a mujoco visualization window. The robot will slowly stand up. If the robot is squatting, manually lift the robot into a standing position. Verify that the robot states match the real world and are updating.
+
+You can supply additional arguments to `dial-mpc-real`:
+
+- `--custom-env`: custom environment definition.
+- `--network-interface`: override network interface configuration.
+- `--plugin`: custom localization plugin.
-🚧 Check back in late Sep. - early Oct. 2024 for real-world deployment pipeline on Unitree Go2.
-
+dial-mpc-plan --example unitree_go2_seq_jump_deploy
+```
## Writing Custom Environment
diff --git a/dial_mpc/core/dial_core.py b/dial_mpc/core/dial_core.py
index 1c34cb7..901aca6 100644
--- a/dial_mpc/core/dial_core.py
+++ b/dial_mpc/core/dial_core.py
@@ -254,7 +254,7 @@ def reverse_scan(rng_Y0_state, factor):
t0 = time.time()
traj_diffuse_factors = (
- dial_config.traj_diffuse_factor ** (jnp.arange(n_diffuse))[:, None]
+ mbdpi.sigma_control * dial_config.traj_diffuse_factor ** (jnp.arange(n_diffuse))[:, None]
)
(rng, Y0, _), info = jax.lax.scan(
reverse_scan, (rng, Y0, state), traj_diffuse_factors
diff --git a/dial_mpc/deploy/dial_real.py b/dial_mpc/deploy/dial_real.py
index e69de29..e81abc8 100644
--- a/dial_mpc/deploy/dial_real.py
+++ b/dial_mpc/deploy/dial_real.py
@@ -0,0 +1,382 @@
+import os
+import time
+import csv
+import sys
+import importlib
+from multiprocessing import shared_memory
+from threading import Thread
+from typing import List, Union
+from dataclasses import dataclass
+
+import mujoco
+import mujoco.viewer
+import numpy as np
+import argparse
+from scipy.spatial.transform import Rotation as R
+import art
+import yaml
+
+from unitree_sdk2py.core.channel import (
+ ChannelSubscriber,
+ ChannelFactoryInitialize,
+ ChannelPublisher,
+)
+from unitree_sdk2py.idl.default import (
+ unitree_go_msg_dds__LowState_,
+ unitree_go_msg_dds__LowCmd_,
+)
+from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowState_
+from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_
+from unitree_sdk2py.utils.crc import CRC
+from unitree_sdk2py.utils.thread import Thread
+
+from dial_mpc.config.base_env_config import BaseEnvConfig
+from dial_mpc.core.dial_config import DialConfig
+import dial_mpc.utils.unitree_legged_const as unitree
+from dial_mpc.utils.io_utils import (
+ load_dataclass_from_dict,
+ get_model_path,
+ get_example_path,
+)
+from dial_mpc.examples import deploy_examples
+from dial_mpc.deploy.localization import load_plugin, get_available_plugins
+
+
+@dataclass
+class DialRealConfig:
+ robot_name: str
+ scene_name: str
+ real_leg_control: str
+ record: bool
+ network_interface: str
+ real_kp: Union[float, List[float]]
+ real_kd: Union[float, List[float]]
+ initial_position_ctrl: List[float]
+ low_cmd_pub_dt: float
+ localization_plugin: str
+ localization_timeout_sec: float
+
+
+class DialReal:
+ def __init__(
+ self,
+ real_config: DialRealConfig,
+ env_config: BaseEnvConfig,
+ dial_config: DialConfig,
+ plugin_config: dict,
+ ):
+ self.leg_control = real_config.real_leg_control
+ if self.leg_control != "position" and self.leg_control != "torque":
+ raise ValueError("Invalid leg control mode")
+ self.record = real_config.record
+ self.data = []
+ # control related
+ self.kp = real_config.real_kp
+ self.kd = real_config.real_kd
+ self.current_kp = 0.0
+ self.mocap_odom = None
+ self.ctrl_dt = env_config.dt
+ self.n_acts = dial_config.Hsample + 1
+ self.t = 0.0
+ self.stand_ctrl = np.array(real_config.initial_position_ctrl, dtype=np.float32)
+ self.low_cmd_pub_dt = real_config.low_cmd_pub_dt
+
+ # load localization plugin
+ self.localization_plugin = load_plugin(real_config.localization_plugin)
+ if self.localization_plugin is None:
+ raise ValueError(
+ f'Failed to load localization plugin "{real_config.localization_plugin}". Please see error messages above. Valid plugins are: {get_available_plugins()}'
+ )
+ self.localization_plugin = self.localization_plugin(plugin_config)
+ self.localization_timeout_sec = real_config.localization_timeout_sec
+
+ # mujoco setup
+ self.mj_model = mujoco.MjModel.from_xml_path(
+ get_model_path(real_config.robot_name, real_config.scene_name).as_posix()
+ )
+ self.mj_model.opt.timestep = real_config.low_cmd_pub_dt
+ self.mj_data = mujoco.MjData(self.mj_model)
+ mujoco.mj_resetDataKeyframe(self.mj_model, self.mj_data, 0)
+ mujoco.mj_forward(self.mj_model, self.mj_data)
+ self.viewer = mujoco.viewer.launch_passive(
+ self.mj_model, self.mj_data, show_left_ui=False, show_right_ui=True
+ )
+
+ # parameters
+ self.Nx = self.mj_model.nq + self.mj_model.nv
+ self.Nu = self.mj_model.nu
+
+ # get home keyframe
+ self.default_q = self.mj_model.keyframe("home").qpos
+ self.default_u = self.mj_model.keyframe("home").ctrl
+
+ # communication setup
+ # publisher
+ self.time_shm = shared_memory.SharedMemory(
+ name="time_shm", create=True, size=32
+ )
+ self.time_shared = np.ndarray(1, dtype=np.float32, buffer=self.time_shm.buf)
+ self.time_shared[0] = 0.0
+ self.state_shm = shared_memory.SharedMemory(
+ name="state_shm", create=True, size=self.Nx * 32
+ )
+ self.state_shared = np.ndarray(
+ (self.Nx,), dtype=np.float32, buffer=self.state_shm.buf
+ )
+ # listener
+ self.acts_shm = shared_memory.SharedMemory(
+ name="acts_shm", create=True, size=self.n_acts * self.Nu * 32
+ )
+ self.acts_shared = np.ndarray(
+ (self.n_acts, self.mj_model.nu), dtype=np.float32, buffer=self.acts_shm.buf
+ )
+ self.acts_shared[:] = self.default_u
+ self.refs_shm = shared_memory.SharedMemory(
+ name="refs_shm", create=True, size=self.n_acts * self.Nu * 3 * 32
+ )
+ self.refs_shared = np.ndarray(
+ (self.n_acts, self.Nu, 3), dtype=np.float32, buffer=self.refs_shm.buf
+ )
+ self.refs_shared[:] = 1.0
+ self.plan_time_shm = shared_memory.SharedMemory(
+ name="plan_time_shm", create=True, size=32
+ )
+ self.plan_time_shared = np.ndarray(
+ 1, dtype=np.float32, buffer=self.plan_time_shm.buf
+ )
+ self.plan_time_shared[0] = -self.ctrl_dt
+
+ self.tau_shm = shared_memory.SharedMemory(
+ name="tau_shm", create=True, size=self.n_acts * self.Nu * 32
+ )
+ self.tau_shared = np.ndarray(
+ (self.n_acts, self.mj_model.nu), dtype=np.float32, buffer=self.tau_shm.buf
+ )
+
+ # unitree pubs and subs
+ self.crc = CRC()
+ ChannelFactoryInitialize(0, real_config.network_interface)
+ self.low_pub = ChannelPublisher("rt/lowcmd", LowCmd_)
+ self.low_pub.Init()
+ self.low_cmd = unitree_go_msg_dds__LowCmd_()
+ self.low_cmd.head[0] = 0xFE
+ self.low_cmd.head[1] = 0xEF
+ self.low_cmd.level_flag = 0xFF
+ self.low_cmd.gpio = 0
+ for i in range(20):
+ self.low_cmd.motor_cmd[i].mode = 0x01 # (PMSM) mode
+ self.low_cmd.motor_cmd[i].q = unitree.PosStopF
+ self.low_cmd.motor_cmd[i].kp = 0
+ self.low_cmd.motor_cmd[i].dq = unitree.VelStopF
+ self.low_cmd.motor_cmd[i].kd = 0
+ self.low_cmd.motor_cmd[i].tau = 0
+ self.low_sub = ChannelSubscriber("rt/lowstate", LowState_)
+ self.low_sub.Init(self.on_low_state, 1)
+
+ # visualization thread
+ self.vis_thread = Thread(target=self.visualize)
+ self.vis_thread.Start()
+
+ def visualize(self):
+ while True:
+ mujoco.mj_step(self.mj_model, self.mj_data)
+ self.viewer.sync()
+ time.sleep(0.05)
+
+ def on_low_state(self, msg: LowState_):
+ localization_output = self.localization_plugin.get_state()
+ if localization_output is None:
+ return
+ now = time.time()
+ localization_time = self.localization_plugin.get_last_update_time()
+ if now - localization_time > self.localization_timeout_sec:
+ print(f"[WARN] Localization plugin timeout: {now - localization_time} s")
+ return
+
+ q = np.zeros(self.mj_model.nq)
+ dq = np.zeros(self.mj_model.nv)
+
+ # copy body pose and velocity from localization plugin
+ q[:7] = localization_output[:7]
+ dq[0:3] = localization_output[7:10]
+
+ # rotate angular velocity into the world frame
+ rot = R.from_quat([q[4], q[5], q[6], q[3]]).as_matrix()
+ # ang_vel_body = np.array([self.mocap_odom.twist.twist.angular.x, self.mocap_odom.twist.twist.angular.y, self.mocap_odom.twist.twist.angular.z])
+ ang_vel_body = np.array([msg.imu_state.gyroscope]).flatten()
+ ang_vel_world = rot @ ang_vel_body
+ dq[3:6] = ang_vel_world
+
+ # update joint positions and velocities
+ for i in range(12):
+ q[7 + i] = msg.motor_state[i].q
+ dq[6 + i] = msg.motor_state[i].dq
+
+ state = np.concatenate([q, dq])
+ self.state_shared[:] = state
+ self.mj_data.qpos = q
+ self.mj_data.qvel = dq
+
+ def main_loop(self):
+ while True:
+ t0 = time.time()
+ if self.plan_time_shared[0] < 0.0:
+ self.mj_data.ctrl = self.stand_ctrl
+ else:
+ delta_time = self.t - self.plan_time_shared[0]
+ delta_step = int(delta_time / self.ctrl_dt)
+ if delta_step >= self.n_acts or delta_step < 0:
+ delta_step = self.n_acts - 1
+ self.mj_data.ctrl = self.acts_shared[delta_step]
+ taus = self.tau_shared[delta_step].copy()
+
+ # mujoco.mj_step(self.mj_model, self.mj_data)
+ self.t += self.low_cmd_pub_dt
+ self.time_shared[:] = self.t
+
+ # self.data.append(np.concatenate([[time.time()], self.mj_data.qpos, self.mj_data.qvel, self.mj_data.ctrl]))
+
+ # publish control
+ for i in range(12):
+ if self.plan_time_shared[0] < 0.0 or self.leg_control == "position":
+ self.low_cmd.motor_cmd[i].q = self.mj_data.ctrl[i]
+ self.low_cmd.motor_cmd[i].kp = (
+ min(self.current_kp, self.kp)
+ if type(self.kp) is float
+ else min(self.current_kp, self.kp[i])
+ )
+ self.low_cmd.motor_cmd[i].dq = 0
+ self.low_cmd.motor_cmd[i].kd = (
+ self.kd if type(self.kd) is float else self.kd[i]
+ )
+ self.low_cmd.motor_cmd[i].tau = 0
+ self.current_kp += 0.005 # ramp up kp to start the robot smoothly
+ else:
+ self.low_cmd.motor_cmd[i].q = 0.0
+ self.low_cmd.motor_cmd[i].kp = 0.0
+ self.low_cmd.motor_cmd[i].dq = 0.0
+ self.low_cmd.motor_cmd[i].kd = (
+ self.kd if type(self.kd) is float else self.kd[i]
+ )
+ self.low_cmd.motor_cmd[i].tau = taus[i] * 1.0
+ self.low_cmd.crc = self.crc.Crc(self.low_cmd)
+ self.low_pub.Write(self.low_cmd)
+
+ if self.plan_time_shared[0] >= 0.0 and self.record:
+ self.data.append(
+ np.concatenate(
+ [
+ [time.time()],
+ self.mj_data.qpos,
+ self.mj_data.qvel,
+ self.mj_data.ctrl,
+ ]
+ )
+ )
+
+ t1 = time.time()
+ duration = t1 - t0
+ if duration < self.low_cmd_pub_dt:
+ time.sleep(self.low_cmd_pub_dt - duration)
+ else:
+ print(f"[WARN] Real loop overruns: {duration * 1000} ms")
+
+ def close(self):
+ self.time_shm.close()
+ self.time_shm.unlink()
+ self.state_shm.close()
+ self.state_shm.unlink()
+ self.acts_shm.close()
+ self.acts_shm.unlink()
+ self.plan_time_shm.close()
+ self.plan_time_shm.unlink()
+
+
+def main(args=None):
+ art.tprint("LeCAR @ CMU\nDIAL-MPC\nREAL", font="big", chr_ignore=True)
+ parser = argparse.ArgumentParser()
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument(
+ "--config", type=str, default="config.yaml", help="Path to config file"
+ )
+ group.add_argument(
+ "--example",
+ type=str,
+ default=None,
+ help="Example to run",
+ )
+ group.add_argument(
+ "--list-examples",
+ action="store_true",
+ help="List available examples",
+ )
+ parser.add_argument(
+ "--custom-env",
+ type=str,
+ default=None,
+ help="Custom environment to import dynamically",
+ )
+ parser.add_argument(
+ "--network-interface",
+ type=str,
+ default=None,
+ help="Network interface override",
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default=None,
+ help="Custom localization plugin to import dynamically",
+ )
+ args = parser.parse_args()
+
+ if args.custom_env is not None:
+ sys.path.append(os.getcwd())
+ importlib.import_module(args.custom_env)
+ if args.plugin is not None:
+ sys.path.append(os.getcwd())
+ importlib.import_module(args.plugin)
+
+ if args.list_examples:
+ print("Available examples:")
+ for example in deploy_examples:
+ print(f" - {example}")
+ return
+ if args.example is not None:
+ if args.example not in deploy_examples:
+ print(f"Example {args.example} not found.")
+ return
+ config_dict = yaml.safe_load(
+ open(get_example_path(args.example + ".yaml"), "r")
+ )
+ else:
+ config_dict = yaml.safe_load(open(args.config, "r"))
+
+ if args.network_interface is not None:
+ config_dict["network_interface"] = args.network_interface
+
+ real_config = load_dataclass_from_dict(DialRealConfig, config_dict)
+ env_config = load_dataclass_from_dict(BaseEnvConfig, config_dict)
+ dial_config = load_dataclass_from_dict(DialConfig, config_dict)
+ real_env = DialReal(real_config, env_config, dial_config, config_dict)
+
+ try:
+ real_env.main_loop()
+ except KeyboardInterrupt:
+ pass
+ finally:
+ if real_env.record:
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
+ data = np.array(real_env.data)
+ output_dir = os.path.join(
+ dial_config.output_dir,
+ f"sim_{dial_config.env_name}_{env_config.task_name}_{timestamp}",
+ )
+ os.makedirs(output_dir)
+ np.save(os.path.join(output_dir, "states"), data)
+
+ real_env.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dial_mpc/deploy/localization/__init__.py b/dial_mpc/deploy/localization/__init__.py
new file mode 100644
index 0000000..c5742de
--- /dev/null
+++ b/dial_mpc/deploy/localization/__init__.py
@@ -0,0 +1,89 @@
+import threading
+import pkgutil
+import os
+import importlib
+
+plugin_registry = {}
+_registry_lock = threading.Lock()
+
+def get_available_plugins():
+ with _registry_lock:
+ return list(plugin_registry.keys())
+
+def discover_builtin_plugins():
+ plugin_path = os.path.dirname(__file__)
+ for finder, name, ispkg in pkgutil.iter_modules([plugin_path]):
+ if name not in plugin_registry and name != 'base_plugin':
+ plugin_registry[name] = None # Placeholder for lazy loading
+
+discover_builtin_plugins()
+
+def register_plugin(name, plugin_cls=None, module_path=None):
+ with _registry_lock:
+ if name in plugin_registry:
+ raise ValueError(f"Plugin '{name}' is already registered.")
+
+ if plugin_cls:
+ # Ensure the plugin class is a subclass of BaseLocalizationPlugin
+ from .base_plugin import BaseLocalizationPlugin
+ if not issubclass(plugin_cls, BaseLocalizationPlugin):
+ raise TypeError("The plugin class must inherit from BaseLocalizationPlugin.")
+ plugin_registry[name] = plugin_cls
+
+ elif module_path:
+ # Dynamically load the module from the given path
+ import importlib.util
+ spec = importlib.util.spec_from_file_location(name, module_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+
+ # Expect the module to define 'BaseLocalizationPlugin' class
+ plugin_cls = getattr(module, 'BaseLocalizationPlugin', None)
+ if not plugin_cls:
+ raise AttributeError(f"No 'BaseLocalizationPlugin' class found in '{module_path}'.")
+ from .base_plugin import BaseLocalizationPlugin
+ if not issubclass(plugin_cls, BaseLocalizationPlugin):
+ raise TypeError("The plugin class must inherit from BaseLocalizationPlugin.")
+ plugin_registry[name] = plugin_cls
+
+ else:
+ raise ValueError("You must provide either 'plugin_cls' or 'module_path'.")
+
+def load_plugin(plugin_name):
+ with _registry_lock:
+ plugin_cls = plugin_registry.get(plugin_name)
+
+ if plugin_cls is None:
+ # Lazy loading of built-in plugins
+ try:
+ module = importlib.import_module(f".{plugin_name}", package=__package__)
+
+ # Find the subclass of BaseLocalizationPlugin in the module
+ from .base_plugin import BaseLocalizationPlugin
+ plugin_classes = [
+ attr for attr in vars(module).values()
+ if isinstance(attr, type) and issubclass(attr, BaseLocalizationPlugin) and attr is not BaseLocalizationPlugin
+ ]
+
+ if not plugin_classes:
+ print(f"No subclass of BaseLocalizationPlugin found in module '{plugin_name}'.")
+ return None
+ elif len(plugin_classes) == 1:
+ plugin_cls = plugin_classes[0]
+ plugin_registry[plugin_name] = plugin_cls
+ else:
+ print(f"Multiple subclasses of BaseLocalizationPlugin found in module '{plugin_name}'.")
+ print("Each plugin module must contain exactly one plugin class.")
+ return None
+ except ImportError as e:
+ print(f"Failed to import plugin '{plugin_name}': {e}")
+ return None
+ elif plugin_cls is not None:
+ # Plugin class already loaded
+ pass
+ else:
+ print(f"Plugin '{plugin_name}' is not registered.")
+ return None
+
+ # Return the plugin type
+ return plugin_cls
diff --git a/dial_mpc/deploy/localization/base_plugin.py b/dial_mpc/deploy/localization/base_plugin.py
new file mode 100644
index 0000000..a1de0ea
--- /dev/null
+++ b/dial_mpc/deploy/localization/base_plugin.py
@@ -0,0 +1,22 @@
+from typing import Dict, Any
+
+
+class BaseLocalizationPlugin:
+ def __init__(self, config: Dict[str, Any]):
+ self.config = config
+
+ def get_state(self):
+ """
+ Returns base qpos (3+4) and qvel (3+3) in a 1D array of size 13.
+ Returns None if no update has been received.
+ ALL VELOCITIES MUST BE RETURNED IN WORLD FRAME.
+ """
+ raise NotImplementedError
+
+ def get_last_update_time(self):
+ """
+ Returns the timestamp (float) of the last update.
+ Returns None if no update has been received.
+ Used to check if the plugin is still alive.
+ """
+ raise NotImplementedError
diff --git a/dial_mpc/deploy/localization/ros2_odometry_plugin.py b/dial_mpc/deploy/localization/ros2_odometry_plugin.py
new file mode 100644
index 0000000..4d587dd
--- /dev/null
+++ b/dial_mpc/deploy/localization/ros2_odometry_plugin.py
@@ -0,0 +1,65 @@
+import numpy as np
+from scipy.spatial.transform import Rotation as R
+
+import rclpy
+from rclpy.node import Node
+from nav_msgs.msg import Odometry
+
+from dial_mpc.deploy.localization.base_plugin import BaseLocalizationPlugin
+
+
+class ROS2OdometryPlugin(BaseLocalizationPlugin, Node):
+ def __init__(self, config):
+ BaseLocalizationPlugin.__init__(self, config)
+ rclpy.init()
+ Node.__init__(self, "ros2_odom_plugin")
+ self.subscription = self.create_subscription(
+ Odometry, config["odom_topic"], self.odom_callback, 1
+ )
+
+ self.qpos = None
+ self.qvel = None
+ self.last_time = None
+
+ def __del__(self):
+ rclpy.shutdown()
+
+ def odom_callback(self, msg):
+ qpos = np.array(
+ [
+ msg.pose.pose.position.x,
+ msg.pose.pose.position.y,
+ msg.pose.pose.position.z,
+ msg.pose.pose.orientation.w,
+ msg.pose.pose.orientation.x,
+ msg.pose.pose.orientation.y,
+ msg.pose.pose.orientation.z,
+ ]
+ )
+ vb = np.array(
+ [
+ msg.twist.twist.linear.x,
+ msg.twist.twist.linear.y,
+ msg.twist.twist.linear.z,
+ ]
+ )
+ ab = np.array(
+ [
+ msg.twist.twist.angular.x,
+ msg.twist.twist.angular.y,
+ msg.twist.twist.angular.z,
+ ]
+ )
+ # rotate velocities to world frame
+ q = R.from_quat([qpos[3], qpos[4], qpos[5], qpos[6]])
+ vw = q.apply(vb)
+ aw = q.apply(ab)
+ self.qpos = qpos
+ self.qvel = np.concatenate([vw, aw])
+ self.last_time = msg.header.stamp.sec + msg.header.stamp.nanosec * 1e-9
+
+ def get_state(self):
+ return np.concatenate([self.qpos, self.qvel]) if self.qpos is not None else None
+
+ def get_last_update_time(self):
+ return self.last_time
diff --git a/dial_mpc/deploy/localization/vicon_shm_plugin.py b/dial_mpc/deploy/localization/vicon_shm_plugin.py
new file mode 100644
index 0000000..9312f6e
--- /dev/null
+++ b/dial_mpc/deploy/localization/vicon_shm_plugin.py
@@ -0,0 +1,221 @@
+import struct
+import time
+from multiprocessing import shared_memory
+import threading
+
+import numpy as np
+from scipy.spatial.transform import Rotation as R
+from scipy.signal import butter, lfilter
+from pyvicon_datastream import tools
+
+from dial_mpc.deploy.localization.base_plugin import BaseLocalizationPlugin
+
+class ViconDemo:
+ def __init__(self, vicon_tracker_ip, vicon_object_name, vicon_z_offset):
+ # Vicon DataStream IP and object name
+ self.vicon_tracker_ip = vicon_tracker_ip
+ self.vicon_object_name = vicon_object_name
+ self.vicon_z_offset = vicon_z_offset
+ # Connect to Vicon DataStream
+ self.tracker = tools.ObjectTracker(self.vicon_tracker_ip)
+ if self.tracker.is_connected:
+ print(f"Connected to Vicon DataStream at {self.vicon_tracker_ip}")
+ else:
+ print(f"Failed to connect to Vicon DataStream at {self.vicon_tracker_ip}")
+ raise Exception(f"Connection to {self.vicon_tracker_ip} failed")
+
+ # Initialize previous values for velocity computation
+ self.prev_time = None
+ self.prev_position = None
+ self.prev_quaternion = None
+
+ # Low-pass filter parameters
+ self.cutoff_freq = 5.0 # Cut-off frequency of the filter (Hz)
+ self.filter_order = 2
+ self.fs = 100.0 # Sampling frequency (Hz)
+ self.b, self.a = butter(
+ self.filter_order, self.cutoff_freq / (0.5 * self.fs), btype="low"
+ )
+
+ # Initialize data buffers for filtering
+ self.vel_buffer = []
+ self.omega_buffer = []
+
+
+ # Initialize shared memory
+ self.shared_mem_name = "mocap_state_shm"
+ self.shared_mem_size = 8 + 13 * 8 # 8 bytes for utime (int64), 13 float64s (13*8 bytes)
+ try:
+ self.state_shm = shared_memory.SharedMemory(name=self.shared_mem_name, create=True, size=self.shared_mem_size)
+ print(f"Attach to shared memory '{self.shared_mem_name}' of size {self.shared_mem_size} bytes.")
+ except FileExistsError:
+ print(f"shared memory does not exist")
+ self.state_buffer = self.state_shm.buf
+
+ def get_vicon_data(self):
+ position = self.tracker.get_position(self.vicon_object_name)
+ if not position:
+ print(f"Cannot get the pose of `{self.vicon_object_name}`.")
+ return None, None, None
+
+ try:
+ obj = position[2][0]
+ _, _, x, y, z, roll_ext, pitch_ext, yaw_ext = obj
+ current_time = time.time()
+ # q = tf_transformations.quaternion_from_euler(roll, pitch, yaw, "rxyz")
+ # roll, pitch, yaw = tf_transformations.euler_from_quaternion(q, "sxyz")
+
+ # Position and orientation
+ position = np.array([x, y, z]) / 1000.0
+ position[2] = position[2] + self.vicon_z_offset
+ rotation = R.from_euler("XYZ", [roll_ext, pitch_ext, yaw_ext], degrees=False)
+ quaternion = rotation.as_quat() # [x, y, z, w]
+
+ return current_time, position, quaternion
+ except Exception as e:
+ print(f"Error retrieving Vicon data: {e}")
+ return None, None, None
+
+ def compute_velocities(self, current_time, position, quaternion):
+ # Initialize velocities
+ linear_velocity = np.zeros(3)
+ angular_velocity = np.zeros(3)
+
+ if (
+ self.prev_time is not None
+ and self.prev_position is not None
+ and self.prev_quaternion is not None
+ ):
+ dt = current_time - self.prev_time
+ if dt > 0:
+ # Linear velocity
+ dp = position - self.prev_position
+ linear_velocity = dp / dt
+
+ # Angular velocity
+ prev_rot = R.from_quat(self.prev_quaternion)
+ curr_rot = R.from_quat(quaternion)
+ delta_rot = curr_rot * prev_rot.inv()
+ delta_angle = delta_rot.as_rotvec()
+ angular_velocity = delta_angle / dt
+ else:
+ # First data point; velocities remain zero
+ pass
+
+ # Update previous values
+ self.prev_time = current_time
+ self.prev_position = position
+ self.prev_quaternion = quaternion
+
+ return linear_velocity, angular_velocity
+
+ def low_pass_filter(self, data_buffer, new_data):
+ # Append new data to the buffer
+ data_buffer.append(new_data)
+ # Keep only the last N samples (buffer size)
+ buffer_size = int(self.fs / self.cutoff_freq) * 3
+ if len(data_buffer) > buffer_size:
+ data_buffer.pop(0)
+ # Apply low-pass filter if enough data points are available
+ if len(data_buffer) >= self.filter_order + 1:
+ data_array = np.array(data_buffer)
+ filtered_data = lfilter(self.b, self.a, data_array, axis=0)[-1]
+ return filtered_data
+ else:
+ return new_data # Not enough data to filter; return the new data as is
+
+ def main_loop(self):
+ print("Starting Vicon data acquisition...")
+ try:
+ while True:
+ # Get Vicon data
+ current_time, position, quaternion = self.get_vicon_data()
+ if position is None:
+ time.sleep(0.01)
+ continue
+
+ # Compute velocities
+ linear_velocity, angular_velocity = self.compute_velocities(
+ current_time, position, quaternion
+ )
+
+ # Apply low-pass filter to velocities
+ filtered_linear_velocity = self.low_pass_filter(
+ self.vel_buffer, linear_velocity
+ )
+ filtered_angular_velocity = self.low_pass_filter(
+ self.omega_buffer, angular_velocity
+ )
+
+ # Prepare data to pack
+ utime = int(current_time * 1e6) # int64
+ data_to_pack = [utime]
+ data_to_pack.extend(position.tolist())
+ data_to_pack.extend(quaternion.tolist())
+ data_to_pack.extend(filtered_linear_velocity.tolist())
+ data_to_pack.extend(filtered_angular_velocity.tolist())
+
+ # Pack data into shared memory buffer
+ struct_format = "q13d"
+ struct.pack_into(struct_format, self.state_buffer, 0, *data_to_pack)
+
+ # Optionally, print or process the filtered data
+ # print(f"Position: {position}")
+ # print(f"Filtered Linear Velocity: {filtered_linear_velocity}")
+ # print(f"Filtered Angular Velocity: {filtered_angular_velocity}")
+ # print(f"Quat:", quaternion)
+ # print("-" * 50)
+
+ # print(f"State:", position)
+ # print("-" * 50)
+
+ # Sleep to mimic sampling rate
+ time.sleep(1.0 / self.fs)
+
+ except KeyboardInterrupt:
+ print("Exiting Vicon data acquisition.")
+ finally:
+ # Close and unlink shared memory
+ try:
+ self.state_shm.close()
+ print(f"Shared memory '{self.shared_mem_name}' closed")
+ except:
+ pass
+
+
+class ViconPlugin(BaseLocalizationPlugin):
+ def __init__(self, config):
+ self.time = time.time()
+ # Initialize Vicon thread
+ vicon_demo = ViconDemo(config['vicon_tracker_ip'], config['vicon_object_name'], config['vicon_z_offset'])
+ self.vicon_thread = threading.Thread(target=vicon_demo.main_loop)
+ self.vicon_thread.start()
+
+ # Initialize shared memory
+ self.shared_mem_name = "mocap_state_shm"
+ self.shared_mem_size = 8 + 13 * 8 # 8 bytes for utime (int64), 13 float64s (13*8 bytes)
+ self.mocap_shm = shared_memory.SharedMemory(name=self.shared_mem_name, create=False, size=self.shared_mem_size)
+ self.state_buffer = self.mocap_shm.buf
+
+ def get_state(self):
+ # Unpack data from shared memory
+ struct_format = "q13d"
+ data = struct.unpack_from(struct_format, self.state_buffer, 0)
+
+ # Extract position, quaternion, linear velocity, and angular velocity
+ utime = data[0]
+ position = np.array(data[1:4])
+ quaternion = np.array(data[4:8])
+ quaternion = np.roll(quaternion, 1) # change quaternion from xyzw to wxyz
+ linear_velocity = np.array(data[8:11])
+ angular_velocity = np.array(data[11:14])
+
+ # Combine position and quaternion into qpos
+ qpos = np.concatenate([position, quaternion])
+ # Combine linear and angular velocities into qvel
+ qvel = np.concatenate([linear_velocity, angular_velocity])
+ self.time = utime
+ return np.concatenate([qpos, qvel])
+
+ def get_last_update_time(self):
+ return self.time
diff --git a/dial_mpc/envs/__init__.py b/dial_mpc/envs/__init__.py
index 8abc934..4009629 100644
--- a/dial_mpc/envs/__init__.py
+++ b/dial_mpc/envs/__init__.py
@@ -2,6 +2,7 @@
from dial_mpc.envs.unitree_h1_env import (
UnitreeH1WalkEnvConfig,
UnitreeH1PushCrateEnvConfig,
+ UnitreeH1LocoEnvConfig,
)
from dial_mpc.envs.unitree_go2_env import (
UnitreeGo2EnvConfig,
@@ -12,6 +13,7 @@
_configs = {
"unitree_h1_walk": UnitreeH1WalkEnvConfig,
"unitree_h1_push_crate": UnitreeH1PushCrateEnvConfig,
+ "unitree_h1_loco": UnitreeH1LocoEnvConfig,
"unitree_go2_walk": UnitreeGo2EnvConfig,
"unitree_go2_seq_jump": UnitreeGo2SeqJumpEnvConfig,
"unitree_go2_crate_climb": UnitreeGo2CrateEnvConfig,
diff --git a/dial_mpc/envs/unitree_h1_env.py b/dial_mpc/envs/unitree_h1_env.py
index b9fb71d..8c3a84b 100644
--- a/dial_mpc/envs/unitree_h1_env.py
+++ b/dial_mpc/envs/unitree_h1_env.py
@@ -567,5 +567,340 @@ def randomize():
return state
+@dataclass
+class UnitreeH1LocoEnvConfig(BaseEnvConfig):
+ kp: Union[float, jax.Array] = field(default_factory=lambda: jnp.array(
+ [
+ 200.0,
+ 200.0,
+ 200.0, # left hips
+ 200.0,
+ 60.0, # left knee, ankle
+ 200.0,
+ 200.0,
+ 200.0, # right hips
+ 200.0,
+ 60.0, # right knee, ankle
+ 200.0, # torso
+ ]
+ ))
+ kd: Union[float, jax.Array] = field(default_factory=lambda: jnp.array(
+ [
+ 5.0,
+ 5.0,
+ 5.0, # left hips
+ 5.0,
+ 1.5, # left knee, ankle
+ 5.0,
+ 5.0,
+ 5.0, # right hips
+ 5.0,
+ 1.5, # right knee, ankle
+ 5.0, # torso
+ ]
+ ))
+ default_vx: float = 1.0
+ default_vy: float = 0.0
+ default_vyaw: float = 0.0
+ ramp_up_time: float = 2.0
+ gait: str = "jog"
+
+
+class UnitreeH1LocoEnv(BaseEnv):
+ def __init__(self, config: UnitreeH1LocoEnvConfig):
+ super().__init__(config)
+
+ # some body indices
+ self._pelvis_idx = mujoco.mj_name2id(
+ self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, "pelvis"
+ )
+ self._torso_idx = mujoco.mj_name2id(
+ self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, "torso_link"
+ )
+
+ self._left_foot_idx = mujoco.mj_name2id(
+ self.sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, "left_foot"
+ )
+ self._right_foot_idx = mujoco.mj_name2id(
+ self.sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, "right_foot"
+ )
+ self._feet_site_id = jnp.array(
+ [self._left_foot_idx, self._right_foot_idx], dtype=jnp.int32
+ )
+ # gait phase
+ self._gait = config.gait
+ self._gait_phase = {
+ "stand": jnp.zeros(2),
+ "slow_walk": jnp.array([0.0, 0.5]),
+ "walk": jnp.array([0.0, 0.5]),
+ "jog": jnp.array([0.0, 0.5]),
+ }
+ self._gait_params = {
+ # ratio, cadence, amplitude
+ "stand": jnp.array([1.0, 1.0, 0.0]),
+ "slow_walk": jnp.array([0.6, 0.8, 0.15]),
+ "walk": jnp.array([0.5, 1.5, 0.10]),
+ "jog": jnp.array([0.3, 2.0, 0.2]),
+ }
+
+ # joint limits and initial pose
+ self._init_q = jnp.array(self.sys.mj_model.keyframe("home").qpos)
+ self._default_pose = self.sys.mj_model.keyframe("home").qpos[7:]
+ # joint sampling range
+ self.joint_range = jnp.array(
+ [
+ [-0.2, 0.2],
+ [-0.2, 0.2],
+ [-0.6, 0.6],
+ [0.0, 1.5],
+ [-0.6, 0.4],
+
+ [-0.2, 0.2],
+ [-0.2, 0.2],
+ [-0.6, 0.6],
+ [0.0, 1.5],
+ [-0.6, 0.4],
+
+ [-0.5, 0.5],
+ ]
+ )
+ # self.joint_range = self.physical_joint_range
+
+ def make_system(self, config: UnitreeH1LocoEnvConfig) -> System:
+ model_path = get_model_path("unitree_h1", "mjx_scene_h1_loco.xml")
+ sys = mjcf.load(model_path)
+ sys = sys.tree_replace({"opt.timestep": config.timestep})
+ return sys
+
+ def reset(self, rng: jax.Array) -> State:
+ rng, key = jax.random.split(rng)
+
+ pipeline_state = self.pipeline_init(self._init_q, jnp.zeros(self._nv))
+
+ state_info = {
+ "rng": rng,
+ "pos_tar": jnp.array([0.0, 0.0, 1.3]),
+ "vel_tar": jnp.zeros(3),
+ "ang_vel_tar": jnp.zeros(3),
+ "yaw_tar": 0.0,
+ "step": 0,
+ "z_feet": jnp.zeros(2),
+ "z_feet_tar": jnp.zeros(2),
+ "randomize_target": self._config.randomize_tasks,
+ "last_contact": jnp.zeros(2, dtype=jnp.bool),
+ "feet_air_time": jnp.zeros(2),
+ }
+
+ obs = self._get_obs(pipeline_state, state_info)
+ reward, done = jnp.zeros(2)
+ metrics = {}
+ state = State(pipeline_state, obs, reward, done, metrics, state_info)
+ return state
+
+ def step(
+ self, state: State, action: jax.Array
+ ) -> State: # pytype: disable=signature-mismatch
+ rng, cmd_rng = jax.random.split(state.info["rng"], 2)
+
+ # physics step
+ joint_targets = self.act2joint(action)
+ if self._config.leg_control == "position":
+ ctrl = joint_targets
+ elif self._config.leg_control == "torque":
+ ctrl = self.act2tau(action, state.pipeline_state)
+ pipeline_state = self.pipeline_step(state.pipeline_state, ctrl)
+ x, xd = pipeline_state.x, pipeline_state.xd
+
+ # observation data
+ obs = self._get_obs(pipeline_state, state.info)
+
+ # switch to new target if randomize_target is True
+ def dont_randomize():
+ return (
+ jnp.array([self._config.default_vx, self._config.default_vy, 0.0]),
+ jnp.array([0.0, 0.0, self._config.default_vyaw]),
+ )
+
+ def randomize():
+ return self.sample_command(cmd_rng)
+
+ vel_tar, ang_vel_tar = jax.lax.cond(
+ (state.info["randomize_target"]) & (state.info["step"] % 500 == 0),
+ randomize,
+ dont_randomize,
+ )
+ state.info["vel_tar"] = jnp.minimum(
+ vel_tar * state.info["step"] * self.dt / self._config.ramp_up_time, vel_tar
+ )
+ state.info["ang_vel_tar"] = jnp.minimum(
+ ang_vel_tar * state.info["step"] * self.dt / self._config.ramp_up_time,
+ ang_vel_tar,
+ )
+
+ # reward
+ # gaits reward
+ # z_feet = pipeline_state.site_xpos[self._feet_site_id][:, 2]
+ duty_ratio, cadence, amplitude = self._gait_params[self._gait]
+ phases = self._gait_phase[self._gait]
+ z_feet_tar = get_foot_step(
+ duty_ratio, cadence, amplitude, phases, state.info["step"] * self.dt
+ )
+ # reward_gaits = -jnp.sum(((z_feet_tar - z_feet)) ** 2)
+ z_feet = jnp.array(
+ [
+ jnp.min(pipeline_state.contact.dist[0:4]),
+ jnp.min(pipeline_state.contact.dist[4:8]),
+ ]
+ )
+ # jax.debug.print("{contact_geom}", contact_geom=pipeline_state.contact.geom)
+ reward_gaits = -jnp.sum((z_feet_tar - z_feet) ** 2)
+ # foot contact data based on z-position
+ # pytype: disable=attribute-error
+ foot_pos = pipeline_state.site_xpos[self._feet_site_id]
+ foot_contact_z = foot_pos[:, 2]
+ contact = foot_contact_z < 1e-3 # a mm or less off the floor
+ contact_filt_mm = contact | state.info["last_contact"]
+ contact_filt_cm = (foot_contact_z < 3e-2) | state.info["last_contact"]
+ first_contact = (state.info["feet_air_time"] > 0) * contact_filt_mm
+ state.info["feet_air_time"] += self.dt
+ reward_air_time = jnp.sum((state.info["feet_air_time"] - 0.1) * first_contact)
+ # position reward
+ pos_tar = (
+ state.info["pos_tar"] + state.info["vel_tar"] * self.dt * state.info["step"]
+ )
+ pos = x.pos[self._torso_idx - 1]
+ reward_pos = -jnp.sum((pos - pos_tar) ** 2)
+ # stay upright reward
+ vec_tar = jnp.array([0.0, 0.0, 1.0])
+ vec = math.rotate(vec_tar, x.rot[0])
+ reward_upright = -jnp.sum(jnp.square(vec - vec_tar))
+ # yaw orientation reward
+ yaw_tar = (
+ state.info["yaw_tar"]
+ + state.info["ang_vel_tar"][2] * self.dt * state.info["step"]
+ )
+ yaw = math.quat_to_euler(x.rot[self._torso_idx - 1])[2]
+ d_yaw = yaw - yaw_tar
+ reward_yaw = -jnp.square(jnp.atan2(jnp.sin(d_yaw), jnp.cos(d_yaw)))
+ # stay to norminal pose reward
+ # reward_pose = -jnp.sum(jnp.square(joint_targets - self._default_pose))
+ # velocity reward
+ vb = global_to_body_velocity(
+ xd.vel[self._torso_idx - 1], x.rot[self._torso_idx - 1]
+ )
+ ab = global_to_body_velocity(
+ xd.ang[self._torso_idx - 1] * jnp.pi / 180.0, x.rot[self._torso_idx - 1]
+ )
+ reward_vel = -jnp.sum((vb[:2] - state.info["vel_tar"][:2]) ** 2)
+ reward_ang_vel = -jnp.sum((ab - state.info["ang_vel_tar"]) ** 2)
+ # height reward
+ reward_height = -jnp.sum(
+ (x.pos[self._torso_idx - 1, 2] - state.info["pos_tar"][2]) ** 2
+ )
+ # foot level reward
+ left_foot_mat = pipeline_state.site_xmat[self._left_foot_idx]
+ right_foot_mat = pipeline_state.site_xmat[self._right_foot_idx]
+ vec_left = left_foot_mat @ vec_tar
+ vec_right = right_foot_mat @ vec_tar
+ reward_foot_level = -jnp.sum((vec_left - vec_tar) ** 2 + (vec_right - vec_tar) ** 2)
+ # energy reward
+ reward_energy = -jnp.sum((ctrl / self.joint_torque_range[:, 1] * pipeline_state.qvel[6:6+len(self.joint_range)] / 160.0) ** 2)
+ #reward_energy = -jnp.sum((ctrl / self.joint_torque_range[:, 1]) ** 2)
+ # stay alive reward
+ reward_alive = 1.0 - state.done
+ # reward
+ reward = (
+ reward_gaits * 10.0
+ + reward_air_time * 0.0
+ + reward_pos * 0.0
+ + reward_upright * 0.5
+ + reward_yaw * 0.5
+ # + reward_pose * 0.0
+ + reward_vel * 1.0
+ + reward_ang_vel * 1.0
+ + reward_height * 0.5
+ + reward_foot_level * 0.02
+ + reward_energy * 0.01
+ + reward_alive * 0.0
+ )
+
+ # done
+ up = jnp.array([0.0, 0.0, 1.0])
+ joint_angles = pipeline_state.q[7:]
+ joint_angles = joint_angles[: len(self.joint_range)]
+ done = jnp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
+ done |= jnp.any(joint_angles < self.joint_range[:, 0])
+ done |= jnp.any(joint_angles > self.joint_range[:, 1])
+ done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18
+ done = done.astype(jnp.float32)
+
+ # state management
+ state.info["step"] += 1
+ state.info["rng"] = rng
+ state.info["z_feet"] = z_feet
+ state.info["z_feet_tar"] = z_feet_tar
+ state.info["feet_air_time"] *= ~contact_filt_mm
+ state.info["last_contact"] = contact
+
+ state = state.replace(
+ pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
+ )
+ return state
+
+ def _get_obs(
+ self,
+ pipeline_state: base.State,
+ state_info: dict[str, Any],
+ ) -> jax.Array:
+ x, xd = pipeline_state.x, pipeline_state.xd
+ vb = global_to_body_velocity(
+ xd.vel[self._torso_idx - 1], x.rot[self._torso_idx - 1]
+ )
+ ab = global_to_body_velocity(
+ xd.ang[self._torso_idx - 1] * jnp.pi / 180.0, x.rot[self._torso_idx - 1]
+ )
+ obs = jnp.concatenate(
+ [
+ state_info["vel_tar"],
+ state_info["ang_vel_tar"],
+ pipeline_state.ctrl,
+ pipeline_state.qpos,
+ vb,
+ ab,
+ pipeline_state.qvel[6:],
+ ]
+ )
+ return obs
+
+ def render(
+ self,
+ trajectory: List[base.State],
+ camera: str | None = None,
+ width: int = 240,
+ height: int = 320,
+ ) -> Sequence[np.ndarray]:
+ camera = camera or "track"
+ return super().render(trajectory, camera=camera, width=width, height=height)
+
+ def sample_command(self, rng: jax.Array) -> tuple[jax.Array, jax.Array]:
+ lin_vel_x = [-1.5, 1.5] # min max [m/s]
+ lin_vel_y = [-0.5, 0.5] # min max [m/s]
+ ang_vel_yaw = [-1.5, 1.5] # min max [rad/s]
+
+ _, key1, key2, key3 = jax.random.split(rng, 4)
+ lin_vel_x = jax.random.uniform(
+ key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1]
+ )
+ lin_vel_y = jax.random.uniform(
+ key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1]
+ )
+ ang_vel_yaw = jax.random.uniform(
+ key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1]
+ )
+ new_lin_vel_cmd = jnp.array([lin_vel_x[0], lin_vel_y[0], 0.0])
+ new_ang_vel_cmd = jnp.array([0.0, 0.0, ang_vel_yaw[0]])
+ return new_lin_vel_cmd, new_ang_vel_cmd
+
brax_envs.register_environment("unitree_h1_walk", UnitreeH1WalkEnv)
brax_envs.register_environment("unitree_h1_push_crate", UnitreeH1PushCrateEnv)
+brax_envs.register_environment("unitree_h1_loco", UnitreeH1LocoEnv)
diff --git a/dial_mpc/examples/__init__.py b/dial_mpc/examples/__init__.py
index 70c7dee..f747a20 100644
--- a/dial_mpc/examples/__init__.py
+++ b/dial_mpc/examples/__init__.py
@@ -1,6 +1,7 @@
examples = [
"unitree_h1_jog",
"unitree_h1_push_crate",
+ "unitree_h1_loco",
"unitree_go2_trot",
"unitree_go2_seq_jump",
"unitree_go2_crate_climb",
@@ -9,4 +10,5 @@
deploy_examples = [
"unitree_go2_trot_deploy",
"unitree_go2_seq_jump_deploy",
+ "unitree_h1_loco_deploy",
]
diff --git a/dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml b/dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml
index f6fda2f..ddbd060 100644
--- a/dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml
+++ b/dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml
@@ -41,3 +41,22 @@ record: false
real_time_factor: 1.0
sim_dt: 0.005
sync_mode: false
+
+# Real
+real_leg_control: torque
+network_interface: "enp2s0"
+real_kp: 30.0
+real_kd: 0.65
+initial_position_ctrl:
+ [0.0, 0.67, -1.3, 0.0, 0.67, -1.3, 0.0, 0.67, -1.3, 0.0, 0.67, -1.3]
+low_cmd_pub_dt: 0.002
+localization_plugin: "ros2_odometry_plugin"
+localization_timeout_sec: 0.1
+
+# ROS2 odometry plugin
+odom_topic: "odometry"
+
+# Vicon plugin
+vicon_tracker_ip: "128.2.184.3"
+vicon_object_name: "lecar_go2"
+vicon_z_offset: 0.0
diff --git a/dial_mpc/examples/unitree_go2_trot_deploy.yaml b/dial_mpc/examples/unitree_go2_trot_deploy.yaml
index 0a11b4e..c77cde8 100644
--- a/dial_mpc/examples/unitree_go2_trot_deploy.yaml
+++ b/dial_mpc/examples/unitree_go2_trot_deploy.yaml
@@ -36,3 +36,22 @@ record: false
real_time_factor: 1.0
sim_dt: 0.005
sync_mode: false
+
+# Real
+real_leg_control: torque
+network_interface: "enp2s0"
+real_kp: 30.0
+real_kd: 0.65
+initial_position_ctrl:
+ [0.0, 0.67, -1.3, 0.0, 0.67, -1.3, 0.0, 0.67, -1.3, 0.0, 0.67, -1.3]
+low_cmd_pub_dt: 0.002
+localization_plugin: "ros2_odometry_plugin"
+localization_timeout_sec: 0.1
+
+# ROS2 odometry plugin
+odom_topic: "odometry"
+
+# Vicon plugin
+vicon_tracker_ip: "128.2.184.3"
+vicon_object_name: "lecar_go2"
+vicon_z_offset: 0.0
diff --git a/dial_mpc/examples/unitree_h1_loco.yaml b/dial_mpc/examples/unitree_h1_loco.yaml
new file mode 100644
index 0000000..41ccaf4
--- /dev/null
+++ b/dial_mpc/examples/unitree_h1_loco.yaml
@@ -0,0 +1,28 @@
+# DIAL-MPC
+seed: 0
+output_dir: unitree_h1_loco
+n_steps: 400
+
+env_name: unitree_h1_loco
+Nsample: 2048
+Hsample: 20
+Hnode: 5
+Ndiffuse: 1
+Ndiffuse_init: 10
+temp_sample: 0.05
+horizon_diffuse_factor: 0.9
+traj_diffuse_factor: 0.2
+update_method: mppi
+
+# Base environment
+dt: 0.02
+timestep: 0.02
+leg_control: torque
+action_scale: 1.0
+
+# H1
+default_vx: 0.6
+default_vy: 0.0
+default_vyaw: 0.0
+ramp_up_time: 3.0
+gait: walk
diff --git a/dial_mpc/examples/unitree_h1_loco_deploy.yaml b/dial_mpc/examples/unitree_h1_loco_deploy.yaml
new file mode 100644
index 0000000..f77ea67
--- /dev/null
+++ b/dial_mpc/examples/unitree_h1_loco_deploy.yaml
@@ -0,0 +1,38 @@
+# DIAL-MPC
+seed: 0
+output_dir: unitree_h1_loco
+n_steps: 400
+
+env_name: unitree_h1_loco
+Nsample: 3000
+Hsample: 20
+Hnode: 5
+Ndiffuse: 1
+Ndiffuse_init: 10
+temp_sample: 0.04
+horizon_diffuse_factor: 0.95
+traj_diffuse_factor: 0.2
+update_method: mppi
+
+# Base environment
+dt: 0.02
+timestep: 0.02
+leg_control: torque
+action_scale: 1.0
+
+# H1
+default_vx: 0.8
+default_vy: 0.0
+default_vyaw: 0.0
+ramp_up_time: 3.0
+gait: walk
+
+# Sim
+robot_name: "unitree_h1"
+scene_name: "scene_h1_loco.xml"
+sim_leg_control: torque
+plot: false
+record: false
+real_time_factor: 1.0
+sim_dt: 0.005
+sync_mode: false
diff --git a/dial_mpc/models/unitree_h1/h1_loco.xml b/dial_mpc/models/unitree_h1/h1_loco.xml
new file mode 100644
index 0000000..f54b361
--- /dev/null
+++ b/dial_mpc/models/unitree_h1/h1_loco.xml
@@ -0,0 +1,232 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dial_mpc/models/unitree_h1/h1_real_feet.xml b/dial_mpc/models/unitree_h1/h1_real_feet.xml
index e894417..d31cc98 100644
--- a/dial_mpc/models/unitree_h1/h1_real_feet.xml
+++ b/dial_mpc/models/unitree_h1/h1_real_feet.xml
@@ -1,8 +1,6 @@
-
-
@@ -13,25 +11,22 @@
+
-
@@ -94,9 +89,8 @@
diaginertia="0.00220848 0.00218961 0.000214202"/>
-
-
-
+
+