From 7ef5cb64a6f9a2e6c8d3a2737210804ade03074b Mon Sep 17 00:00:00 2001 From: Haoru Xue Date: Wed, 30 Oct 2024 15:25:05 -0700 Subject: [PATCH] add h1 locomotion env --- dial_mpc/core/dial_core.py | 2 +- dial_mpc/envs/__init__.py | 2 + dial_mpc/envs/unitree_h1_env.py | 335 ++++++++++++++++++ dial_mpc/examples/__init__.py | 2 + dial_mpc/examples/unitree_h1_loco.yaml | 28 ++ dial_mpc/examples/unitree_h1_loco_deploy.yaml | 38 ++ dial_mpc/models/unitree_h1/h1_loco.xml | 232 ++++++++++++ dial_mpc/models/unitree_h1/h1_real_feet.xml | 27 +- dial_mpc/models/unitree_h1/mjx_h1_loco.xml | 233 ++++++++++++ .../models/unitree_h1/mjx_scene_h1_loco.xml | 33 ++ dial_mpc/models/unitree_h1/scene_h1_loco.xml | 33 ++ 11 files changed, 947 insertions(+), 18 deletions(-) create mode 100644 dial_mpc/examples/unitree_h1_loco.yaml create mode 100644 dial_mpc/examples/unitree_h1_loco_deploy.yaml create mode 100644 dial_mpc/models/unitree_h1/h1_loco.xml create mode 100644 dial_mpc/models/unitree_h1/mjx_h1_loco.xml create mode 100644 dial_mpc/models/unitree_h1/mjx_scene_h1_loco.xml create mode 100644 dial_mpc/models/unitree_h1/scene_h1_loco.xml diff --git a/dial_mpc/core/dial_core.py b/dial_mpc/core/dial_core.py index 1c34cb7..901aca6 100644 --- a/dial_mpc/core/dial_core.py +++ b/dial_mpc/core/dial_core.py @@ -254,7 +254,7 @@ def reverse_scan(rng_Y0_state, factor): t0 = time.time() traj_diffuse_factors = ( - dial_config.traj_diffuse_factor ** (jnp.arange(n_diffuse))[:, None] + mbdpi.sigma_control * dial_config.traj_diffuse_factor ** (jnp.arange(n_diffuse))[:, None] ) (rng, Y0, _), info = jax.lax.scan( reverse_scan, (rng, Y0, state), traj_diffuse_factors diff --git a/dial_mpc/envs/__init__.py b/dial_mpc/envs/__init__.py index 8abc934..4009629 100644 --- a/dial_mpc/envs/__init__.py +++ b/dial_mpc/envs/__init__.py @@ -2,6 +2,7 @@ from dial_mpc.envs.unitree_h1_env import ( UnitreeH1WalkEnvConfig, UnitreeH1PushCrateEnvConfig, + UnitreeH1LocoEnvConfig, ) from dial_mpc.envs.unitree_go2_env import ( 
UnitreeGo2EnvConfig, @@ -12,6 +13,7 @@ _configs = { "unitree_h1_walk": UnitreeH1WalkEnvConfig, "unitree_h1_push_crate": UnitreeH1PushCrateEnvConfig, + "unitree_h1_loco": UnitreeH1LocoEnvConfig, "unitree_go2_walk": UnitreeGo2EnvConfig, "unitree_go2_seq_jump": UnitreeGo2SeqJumpEnvConfig, "unitree_go2_crate_climb": UnitreeGo2CrateEnvConfig, diff --git a/dial_mpc/envs/unitree_h1_env.py b/dial_mpc/envs/unitree_h1_env.py index b9fb71d..8c3a84b 100644 --- a/dial_mpc/envs/unitree_h1_env.py +++ b/dial_mpc/envs/unitree_h1_env.py @@ -567,5 +567,340 @@ def randomize(): return state +@dataclass +class UnitreeH1LocoEnvConfig(BaseEnvConfig): + kp: Union[float, jax.Array] = field(default_factory=lambda: jnp.array( + [ + 200.0, + 200.0, + 200.0, # left hips + 200.0, + 60.0, # left knee, ankle + 200.0, + 200.0, + 200.0, # right hips + 200.0, + 60.0, # right knee, ankle + 200.0, # torso + ] + )) + kd: Union[float, jax.Array] = field(default_factory=lambda: jnp.array( + [ + 5.0, + 5.0, + 5.0, # left hips + 5.0, + 1.5, # left knee, ankle + 5.0, + 5.0, + 5.0, # right hips + 5.0, + 1.5, # right knee, ankle + 5.0, # torso + ] + )) + default_vx: float = 1.0 + default_vy: float = 0.0 + default_vyaw: float = 0.0 + ramp_up_time: float = 2.0 + gait: str = "jog" + + +class UnitreeH1LocoEnv(BaseEnv): + def __init__(self, config: UnitreeH1LocoEnvConfig): + super().__init__(config) + + # some body indices + self._pelvis_idx = mujoco.mj_name2id( + self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, "pelvis" + ) + self._torso_idx = mujoco.mj_name2id( + self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, "torso_link" + ) + + self._left_foot_idx = mujoco.mj_name2id( + self.sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, "left_foot" + ) + self._right_foot_idx = mujoco.mj_name2id( + self.sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, "right_foot" + ) + self._feet_site_id = jnp.array( + [self._left_foot_idx, self._right_foot_idx], dtype=jnp.int32 + ) + # gait phase + self._gait = config.gait + 
self._gait_phase = { + "stand": jnp.zeros(2), + "slow_walk": jnp.array([0.0, 0.5]), + "walk": jnp.array([0.0, 0.5]), + "jog": jnp.array([0.0, 0.5]), + } + self._gait_params = { + # ratio, cadence, amplitude + "stand": jnp.array([1.0, 1.0, 0.0]), + "slow_walk": jnp.array([0.6, 0.8, 0.15]), + "walk": jnp.array([0.5, 1.5, 0.10]), + "jog": jnp.array([0.3, 2.0, 0.2]), + } + + # joint limits and initial pose + self._init_q = jnp.array(self.sys.mj_model.keyframe("home").qpos) + self._default_pose = self.sys.mj_model.keyframe("home").qpos[7:] + # joint sampling range + self.joint_range = jnp.array( + [ + [-0.2, 0.2], + [-0.2, 0.2], + [-0.6, 0.6], + [0.0, 1.5], + [-0.6, 0.4], + + [-0.2, 0.2], + [-0.2, 0.2], + [-0.6, 0.6], + [0.0, 1.5], + [-0.6, 0.4], + + [-0.5, 0.5], + ] + ) + # self.joint_range = self.physical_joint_range + + def make_system(self, config: UnitreeH1LocoEnvConfig) -> System: + model_path = get_model_path("unitree_h1", "mjx_scene_h1_loco.xml") + sys = mjcf.load(model_path) + sys = sys.tree_replace({"opt.timestep": config.timestep}) + return sys + + def reset(self, rng: jax.Array) -> State: + rng, key = jax.random.split(rng) + + pipeline_state = self.pipeline_init(self._init_q, jnp.zeros(self._nv)) + + state_info = { + "rng": rng, + "pos_tar": jnp.array([0.0, 0.0, 1.3]), + "vel_tar": jnp.zeros(3), + "ang_vel_tar": jnp.zeros(3), + "yaw_tar": 0.0, + "step": 0, + "z_feet": jnp.zeros(2), + "z_feet_tar": jnp.zeros(2), + "randomize_target": self._config.randomize_tasks, + "last_contact": jnp.zeros(2, dtype=jnp.bool), + "feet_air_time": jnp.zeros(2), + } + + obs = self._get_obs(pipeline_state, state_info) + reward, done = jnp.zeros(2) + metrics = {} + state = State(pipeline_state, obs, reward, done, metrics, state_info) + return state + + def step( + self, state: State, action: jax.Array + ) -> State: # pytype: disable=signature-mismatch + rng, cmd_rng = jax.random.split(state.info["rng"], 2) + + # physics step + joint_targets = self.act2joint(action) + if 
self._config.leg_control == "position": + ctrl = joint_targets + elif self._config.leg_control == "torque": + ctrl = self.act2tau(action, state.pipeline_state) + pipeline_state = self.pipeline_step(state.pipeline_state, ctrl) + x, xd = pipeline_state.x, pipeline_state.xd + + # observation data + obs = self._get_obs(pipeline_state, state.info) + + # switch to new target if randomize_target is True + def dont_randomize(): + return ( + jnp.array([self._config.default_vx, self._config.default_vy, 0.0]), + jnp.array([0.0, 0.0, self._config.default_vyaw]), + ) + + def randomize(): + return self.sample_command(cmd_rng) + + vel_tar, ang_vel_tar = jax.lax.cond( + (state.info["randomize_target"]) & (state.info["step"] % 500 == 0), + randomize, + dont_randomize, + ) + state.info["vel_tar"] = jnp.minimum( + vel_tar * state.info["step"] * self.dt / self._config.ramp_up_time, vel_tar + ) + state.info["ang_vel_tar"] = jnp.minimum( + ang_vel_tar * state.info["step"] * self.dt / self._config.ramp_up_time, + ang_vel_tar, + ) + + # reward + # gaits reward + # z_feet = pipeline_state.site_xpos[self._feet_site_id][:, 2] + duty_ratio, cadence, amplitude = self._gait_params[self._gait] + phases = self._gait_phase[self._gait] + z_feet_tar = get_foot_step( + duty_ratio, cadence, amplitude, phases, state.info["step"] * self.dt + ) + # reward_gaits = -jnp.sum(((z_feet_tar - z_feet)) ** 2) + z_feet = jnp.array( + [ + jnp.min(pipeline_state.contact.dist[0:4]), + jnp.min(pipeline_state.contact.dist[4:8]), + ] + ) + # jax.debug.print("{contact_geom}", contact_geom=pipeline_state.contact.geom) + reward_gaits = -jnp.sum((z_feet_tar - z_feet) ** 2) + # foot contact data based on z-position + # pytype: disable=attribute-error + foot_pos = pipeline_state.site_xpos[self._feet_site_id] + foot_contact_z = foot_pos[:, 2] + contact = foot_contact_z < 1e-3 # a mm or less off the floor + contact_filt_mm = contact | state.info["last_contact"] + contact_filt_cm = (foot_contact_z < 3e-2) | 
state.info["last_contact"] + first_contact = (state.info["feet_air_time"] > 0) * contact_filt_mm + state.info["feet_air_time"] += self.dt + reward_air_time = jnp.sum((state.info["feet_air_time"] - 0.1) * first_contact) + # position reward + pos_tar = ( + state.info["pos_tar"] + state.info["vel_tar"] * self.dt * state.info["step"] + ) + pos = x.pos[self._torso_idx - 1] + reward_pos = -jnp.sum((pos - pos_tar) ** 2) + # stay upright reward + vec_tar = jnp.array([0.0, 0.0, 1.0]) + vec = math.rotate(vec_tar, x.rot[0]) + reward_upright = -jnp.sum(jnp.square(vec - vec_tar)) + # yaw orientation reward + yaw_tar = ( + state.info["yaw_tar"] + + state.info["ang_vel_tar"][2] * self.dt * state.info["step"] + ) + yaw = math.quat_to_euler(x.rot[self._torso_idx - 1])[2] + d_yaw = yaw - yaw_tar + reward_yaw = -jnp.square(jnp.atan2(jnp.sin(d_yaw), jnp.cos(d_yaw))) + # stay close to nominal pose reward + # reward_pose = -jnp.sum(jnp.square(joint_targets - self._default_pose)) + # velocity reward + vb = global_to_body_velocity( + xd.vel[self._torso_idx - 1], x.rot[self._torso_idx - 1] + ) + ab = global_to_body_velocity( + xd.ang[self._torso_idx - 1] * jnp.pi / 180.0, x.rot[self._torso_idx - 1] + ) + reward_vel = -jnp.sum((vb[:2] - state.info["vel_tar"][:2]) ** 2) + reward_ang_vel = -jnp.sum((ab - state.info["ang_vel_tar"]) ** 2) + # height reward + reward_height = -jnp.sum( + (x.pos[self._torso_idx - 1, 2] - state.info["pos_tar"][2]) ** 2 + ) + # foot level reward + left_foot_mat = pipeline_state.site_xmat[self._left_foot_idx] + right_foot_mat = pipeline_state.site_xmat[self._right_foot_idx] + vec_left = left_foot_mat @ vec_tar + vec_right = right_foot_mat @ vec_tar + reward_foot_level = -jnp.sum((vec_left - vec_tar) ** 2 + (vec_right - vec_tar) ** 2) + # energy reward + reward_energy = -jnp.sum((ctrl / self.joint_torque_range[:, 1] * pipeline_state.qvel[6:6+len(self.joint_range)] / 160.0) ** 2) + #reward_energy = -jnp.sum((ctrl / self.joint_torque_range[:, 1]) ** 2) + # stay alive reward 
+ reward_alive = 1.0 - state.done + # reward + reward = ( + reward_gaits * 10.0 + + reward_air_time * 0.0 + + reward_pos * 0.0 + + reward_upright * 0.5 + + reward_yaw * 0.5 + # + reward_pose * 0.0 + + reward_vel * 1.0 + + reward_ang_vel * 1.0 + + reward_height * 0.5 + + reward_foot_level * 0.02 + + reward_energy * 0.01 + + reward_alive * 0.0 + ) + + # done + up = jnp.array([0.0, 0.0, 1.0]) + joint_angles = pipeline_state.q[7:] + joint_angles = joint_angles[: len(self.joint_range)] + done = jnp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0 + done |= jnp.any(joint_angles < self.joint_range[:, 0]) + done |= jnp.any(joint_angles > self.joint_range[:, 1]) + done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18 + done = done.astype(jnp.float32) + + # state management + state.info["step"] += 1 + state.info["rng"] = rng + state.info["z_feet"] = z_feet + state.info["z_feet_tar"] = z_feet_tar + state.info["feet_air_time"] *= ~contact_filt_mm + state.info["last_contact"] = contact + + state = state.replace( + pipeline_state=pipeline_state, obs=obs, reward=reward, done=done + ) + return state + + def _get_obs( + self, + pipeline_state: base.State, + state_info: dict[str, Any], + ) -> jax.Array: + x, xd = pipeline_state.x, pipeline_state.xd + vb = global_to_body_velocity( + xd.vel[self._torso_idx - 1], x.rot[self._torso_idx - 1] + ) + ab = global_to_body_velocity( + xd.ang[self._torso_idx - 1] * jnp.pi / 180.0, x.rot[self._torso_idx - 1] + ) + obs = jnp.concatenate( + [ + state_info["vel_tar"], + state_info["ang_vel_tar"], + pipeline_state.ctrl, + pipeline_state.qpos, + vb, + ab, + pipeline_state.qvel[6:], + ] + ) + return obs + + def render( + self, + trajectory: List[base.State], + camera: str | None = None, + width: int = 240, + height: int = 320, + ) -> Sequence[np.ndarray]: + camera = camera or "track" + return super().render(trajectory, camera=camera, width=width, height=height) + + def sample_command(self, rng: jax.Array) -> tuple[jax.Array, jax.Array]: 
+ lin_vel_x = [-1.5, 1.5] # min max [m/s] + lin_vel_y = [-0.5, 0.5] # min max [m/s] + ang_vel_yaw = [-1.5, 1.5] # min max [rad/s] + + _, key1, key2, key3 = jax.random.split(rng, 4) + lin_vel_x = jax.random.uniform( + key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1] + ) + lin_vel_y = jax.random.uniform( + key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1] + ) + ang_vel_yaw = jax.random.uniform( + key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1] + ) + new_lin_vel_cmd = jnp.array([lin_vel_x[0], lin_vel_y[0], 0.0]) + new_ang_vel_cmd = jnp.array([0.0, 0.0, ang_vel_yaw[0]]) + return new_lin_vel_cmd, new_ang_vel_cmd + brax_envs.register_environment("unitree_h1_walk", UnitreeH1WalkEnv) brax_envs.register_environment("unitree_h1_push_crate", UnitreeH1PushCrateEnv) +brax_envs.register_environment("unitree_h1_loco", UnitreeH1LocoEnv) diff --git a/dial_mpc/examples/__init__.py b/dial_mpc/examples/__init__.py index 70c7dee..f747a20 100644 --- a/dial_mpc/examples/__init__.py +++ b/dial_mpc/examples/__init__.py @@ -1,6 +1,7 @@ examples = [ "unitree_h1_jog", "unitree_h1_push_crate", + "unitree_h1_loco", "unitree_go2_trot", "unitree_go2_seq_jump", "unitree_go2_crate_climb", @@ -9,4 +10,5 @@ deploy_examples = [ "unitree_go2_trot_deploy", "unitree_go2_seq_jump_deploy", + "unitree_h1_loco_deploy", ] diff --git a/dial_mpc/examples/unitree_h1_loco.yaml b/dial_mpc/examples/unitree_h1_loco.yaml new file mode 100644 index 0000000..41ccaf4 --- /dev/null +++ b/dial_mpc/examples/unitree_h1_loco.yaml @@ -0,0 +1,28 @@ +# DIAL-MPC +seed: 0 +output_dir: unitree_h1_loco +n_steps: 400 + +env_name: unitree_h1_loco +Nsample: 2048 +Hsample: 20 +Hnode: 5 +Ndiffuse: 1 +Ndiffuse_init: 10 +temp_sample: 0.05 +horizon_diffuse_factor: 0.9 +traj_diffuse_factor: 0.2 +update_method: mppi + +# Base environment +dt: 0.02 +timestep: 0.02 +leg_control: torque +action_scale: 1.0 + +# H1 +default_vx: 0.6 +default_vy: 0.0 +default_vyaw: 0.0 +ramp_up_time: 3.0 +gait: walk diff --git 
a/dial_mpc/examples/unitree_h1_loco_deploy.yaml b/dial_mpc/examples/unitree_h1_loco_deploy.yaml new file mode 100644 index 0000000..a64ed94 --- /dev/null +++ b/dial_mpc/examples/unitree_h1_loco_deploy.yaml @@ -0,0 +1,38 @@ +# DIAL-MPC +seed: 0 +output_dir: unitree_h1_loco +n_steps: 400 + +env_name: unitree_h1_loco +Nsample: 4096 +Hsample: 20 +Hnode: 5 +Ndiffuse: 1 +Ndiffuse_init: 10 +temp_sample: 0.04 +horizon_diffuse_factor: 0.95 +traj_diffuse_factor: 0.2 +update_method: mppi + +# Base environment +dt: 0.02 +timestep: 0.02 +leg_control: torque +action_scale: 1.0 + +# H1 +default_vx: 0.8 +default_vy: 0.0 +default_vyaw: 0.0 +ramp_up_time: 3.0 +gait: walk + +# Sim +robot_name: "unitree_h1" +scene_name: "scene_h1_loco.xml" +sim_leg_control: torque +plot: false +record: false +real_time_factor: 1.0 +sim_dt: 0.005 +sync_mode: false diff --git a/dial_mpc/models/unitree_h1/h1_loco.xml b/dial_mpc/models/unitree_h1/h1_loco.xml new file mode 100644 index 0000000..f54b361 --- /dev/null +++ b/dial_mpc/models/unitree_h1/h1_loco.xml @@ -0,0 +1,232 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dial_mpc/models/unitree_h1/h1_real_feet.xml b/dial_mpc/models/unitree_h1/h1_real_feet.xml index e894417..d31cc98 100644 --- a/dial_mpc/models/unitree_h1/h1_real_feet.xml +++ b/dial_mpc/models/unitree_h1/h1_real_feet.xml @@ -1,8 +1,6 @@ - - @@ -13,25 +11,22 @@ - - - - - - - - + + + + + + - @@ -94,9 +89,8 @@ diaginertia="0.00220848 0.00218961 0.000214202"/> - - - + + @@ -134,9 +128,8 @@ diaginertia="0.00220848 0.00218961 0.000214202"/> - - - + + diff --git a/dial_mpc/models/unitree_h1/mjx_h1_loco.xml 
b/dial_mpc/models/unitree_h1/mjx_h1_loco.xml new file mode 100644 index 0000000..c3af483 --- /dev/null +++ b/dial_mpc/models/unitree_h1/mjx_h1_loco.xml @@ -0,0 +1,233 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dial_mpc/models/unitree_h1/mjx_scene_h1_loco.xml b/dial_mpc/models/unitree_h1/mjx_scene_h1_loco.xml new file mode 100644 index 0000000..813c945 --- /dev/null +++ b/dial_mpc/models/unitree_h1/mjx_scene_h1_loco.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dial_mpc/models/unitree_h1/scene_h1_loco.xml b/dial_mpc/models/unitree_h1/scene_h1_loco.xml new file mode 100644 index 0000000..452522a --- /dev/null +++ b/dial_mpc/models/unitree_h1/scene_h1_loco.xml @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file