
add h1 locomotion env
HaoruXue committed Nov 6, 2024
1 parent 140f053 commit 7ef5cb6
Showing 11 changed files with 947 additions and 18 deletions.
2 changes: 1 addition & 1 deletion dial_mpc/core/dial_core.py
@@ -254,7 +254,7 @@ def reverse_scan(rng_Y0_state, factor):

t0 = time.time()
traj_diffuse_factors = (
-        dial_config.traj_diffuse_factor ** (jnp.arange(n_diffuse))[:, None]
+        mbdpi.sigma_control * dial_config.traj_diffuse_factor ** (jnp.arange(n_diffuse))[:, None]
)
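    # The diffusion schedule is now anchored to the controller's base noise
    # level: sigma_control scales traj_diffuse_factor**i instead of the
    # schedule starting from 1.0.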
(rng, Y0, _), info = jax.lax.scan(
reverse_scan, (rng, Y0, state), traj_diffuse_factors
2 changes: 2 additions & 0 deletions dial_mpc/envs/__init__.py
@@ -2,6 +2,7 @@
from dial_mpc.envs.unitree_h1_env import (
UnitreeH1WalkEnvConfig,
UnitreeH1PushCrateEnvConfig,
+    UnitreeH1LocoEnvConfig,
)
from dial_mpc.envs.unitree_go2_env import (
UnitreeGo2EnvConfig,
@@ -12,6 +13,7 @@
_configs = {
"unitree_h1_walk": UnitreeH1WalkEnvConfig,
"unitree_h1_push_crate": UnitreeH1PushCrateEnvConfig,
"unitree_h1_loco": UnitreeH1LocoEnvConfig,
"unitree_go2_walk": UnitreeGo2EnvConfig,
"unitree_go2_seq_jump": UnitreeGo2SeqJumpEnvConfig,
"unitree_go2_crate_climb": UnitreeGo2CrateEnvConfig,
335 changes: 335 additions & 0 deletions dial_mpc/envs/unitree_h1_env.py
@@ -567,5 +567,340 @@ def randomize():
return state


@dataclass
class UnitreeH1LocoEnvConfig(BaseEnvConfig):
kp: Union[float, jax.Array] = field(default_factory=lambda: jnp.array(
[
200.0,
200.0,
200.0, # left hips
200.0,
60.0, # left knee, ankle
200.0,
200.0,
200.0, # right hips
200.0,
60.0, # right knee, ankle
200.0, # torso
]
))
kd: Union[float, jax.Array] = field(default_factory=lambda: jnp.array(
[
5.0,
5.0,
5.0, # left hips
5.0,
1.5, # left knee, ankle
5.0,
5.0,
5.0, # right hips
5.0,
1.5, # right knee, ankle
5.0, # torso
]
))
default_vx: float = 1.0
default_vy: float = 0.0
default_vyaw: float = 0.0
ramp_up_time: float = 2.0
gait: str = "jog"


class UnitreeH1LocoEnv(BaseEnv):
def __init__(self, config: UnitreeH1LocoEnvConfig):
super().__init__(config)

# some body indices
self._pelvis_idx = mujoco.mj_name2id(
self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, "pelvis"
)
self._torso_idx = mujoco.mj_name2id(
self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, "torso_link"
)

self._left_foot_idx = mujoco.mj_name2id(
self.sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, "left_foot"
)
self._right_foot_idx = mujoco.mj_name2id(
self.sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, "right_foot"
)
self._feet_site_id = jnp.array(
[self._left_foot_idx, self._right_foot_idx], dtype=jnp.int32
)
# gait phase
self._gait = config.gait
self._gait_phase = {
"stand": jnp.zeros(2),
"slow_walk": jnp.array([0.0, 0.5]),
"walk": jnp.array([0.0, 0.5]),
"jog": jnp.array([0.0, 0.5]),
}
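        # Phase offsets in [0, 1): a 0.5 offset keeps the two feet half a
        # cycle apart (alternating swing); "stand" keeps them in phase.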
self._gait_params = {
# ratio, cadence, amplitude
"stand": jnp.array([1.0, 1.0, 0.0]),
"slow_walk": jnp.array([0.6, 0.8, 0.15]),
"walk": jnp.array([0.5, 1.5, 0.10]),
"jog": jnp.array([0.3, 2.0, 0.2]),
}
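        # Column meanings, following how get_foot_step consumes them in
        # step(): duty ratio = stance fraction of the cycle, cadence = step
        # rate, amplitude = peak swing-foot height.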

# joint limits and initial pose
self._init_q = jnp.array(self.sys.mj_model.keyframe("home").qpos)
self._default_pose = self.sys.mj_model.keyframe("home").qpos[7:]
# joint sampling range
self.joint_range = jnp.array(
[
[-0.2, 0.2],
[-0.2, 0.2],
[-0.6, 0.6],
[0.0, 1.5],
[-0.6, 0.4],

[-0.2, 0.2],
[-0.2, 0.2],
[-0.6, 0.6],
[0.0, 1.5],
[-0.6, 0.4],

[-0.5, 0.5],
]
)
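        # Rows follow the actuator order used for kp/kd above:
        # left leg (5), right leg (5), torso (1).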
# self.joint_range = self.physical_joint_range

def make_system(self, config: UnitreeH1LocoEnvConfig) -> System:
model_path = get_model_path("unitree_h1", "mjx_scene_h1_loco.xml")
sys = mjcf.load(model_path)
sys = sys.tree_replace({"opt.timestep": config.timestep})
return sys

def reset(self, rng: jax.Array) -> State:
rng, key = jax.random.split(rng)

pipeline_state = self.pipeline_init(self._init_q, jnp.zeros(self._nv))

state_info = {
"rng": rng,
"pos_tar": jnp.array([0.0, 0.0, 1.3]),
"vel_tar": jnp.zeros(3),
"ang_vel_tar": jnp.zeros(3),
"yaw_tar": 0.0,
"step": 0,
"z_feet": jnp.zeros(2),
"z_feet_tar": jnp.zeros(2),
"randomize_target": self._config.randomize_tasks,
"last_contact": jnp.zeros(2, dtype=jnp.bool),
"feet_air_time": jnp.zeros(2),
}

obs = self._get_obs(pipeline_state, state_info)
reward, done = jnp.zeros(2)
metrics = {}
state = State(pipeline_state, obs, reward, done, metrics, state_info)
return state

def step(
self, state: State, action: jax.Array
) -> State: # pytype: disable=signature-mismatch
rng, cmd_rng = jax.random.split(state.info["rng"], 2)

# physics step
joint_targets = self.act2joint(action)
if self._config.leg_control == "position":
ctrl = joint_targets
elif self._config.leg_control == "torque":
ctrl = self.act2tau(action, state.pipeline_state)
pipeline_state = self.pipeline_step(state.pipeline_state, ctrl)
x, xd = pipeline_state.x, pipeline_state.xd

# observation data
obs = self._get_obs(pipeline_state, state.info)

# switch to new target if randomize_target is True
def dont_randomize():
return (
jnp.array([self._config.default_vx, self._config.default_vy, 0.0]),
jnp.array([0.0, 0.0, self._config.default_vyaw]),
)

def randomize():
return self.sample_command(cmd_rng)

vel_tar, ang_vel_tar = jax.lax.cond(
(state.info["randomize_target"]) & (state.info["step"] % 500 == 0),
randomize,
dont_randomize,
)
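        # Ramp the commanded velocities linearly from zero to the target over
        # ramp_up_time seconds.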
state.info["vel_tar"] = jnp.minimum(
vel_tar * state.info["step"] * self.dt / self._config.ramp_up_time, vel_tar
)
state.info["ang_vel_tar"] = jnp.minimum(
ang_vel_tar * state.info["step"] * self.dt / self._config.ramp_up_time,
ang_vel_tar,
)

# reward
# gaits reward
# z_feet = pipeline_state.site_xpos[self._feet_site_id][:, 2]
duty_ratio, cadence, amplitude = self._gait_params[self._gait]
phases = self._gait_phase[self._gait]
z_feet_tar = get_foot_step(
duty_ratio, cadence, amplitude, phases, state.info["step"] * self.dt
)
# reward_gaits = -jnp.sum(((z_feet_tar - z_feet)) ** 2)
z_feet = jnp.array(
[
jnp.min(pipeline_state.contact.dist[0:4]),
jnp.min(pipeline_state.contact.dist[4:8]),
]
)
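        # Contact distances are assumed ordered left foot (slots 0-3) then
        # right foot (slots 4-7); the minimum distance per foot approximates
        # its height above the floor.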
# jax.debug.print("{contact_geom}", contact_geom=pipeline_state.contact.geom)
reward_gaits = -jnp.sum((z_feet_tar - z_feet) ** 2)
# foot contact data based on z-position
# pytype: disable=attribute-error
foot_pos = pipeline_state.site_xpos[self._feet_site_id]
foot_contact_z = foot_pos[:, 2]
contact = foot_contact_z < 1e-3 # a mm or less off the floor
contact_filt_mm = contact | state.info["last_contact"]
contact_filt_cm = (foot_contact_z < 3e-2) | state.info["last_contact"]
first_contact = (state.info["feet_air_time"] > 0) * contact_filt_mm
state.info["feet_air_time"] += self.dt
reward_air_time = jnp.sum((state.info["feet_air_time"] - 0.1) * first_contact)
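        # First-touchdown bonus for swings longer than 0.1 s (currently given
        # zero weight in the total reward below).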
# position reward
pos_tar = (
state.info["pos_tar"] + state.info["vel_tar"] * self.dt * state.info["step"]
)
pos = x.pos[self._torso_idx - 1]
reward_pos = -jnp.sum((pos - pos_tar) ** 2)
# stay upright reward
vec_tar = jnp.array([0.0, 0.0, 1.0])
vec = math.rotate(vec_tar, x.rot[0])
reward_upright = -jnp.sum(jnp.square(vec - vec_tar))
# yaw orientation reward
yaw_tar = (
state.info["yaw_tar"]
+ state.info["ang_vel_tar"][2] * self.dt * state.info["step"]
)
yaw = math.quat_to_euler(x.rot[self._torso_idx - 1])[2]
d_yaw = yaw - yaw_tar
reward_yaw = -jnp.square(jnp.atan2(jnp.sin(d_yaw), jnp.cos(d_yaw)))
        # stay near nominal pose reward
# reward_pose = -jnp.sum(jnp.square(joint_targets - self._default_pose))
# velocity reward
vb = global_to_body_velocity(
xd.vel[self._torso_idx - 1], x.rot[self._torso_idx - 1]
)
ab = global_to_body_velocity(
xd.ang[self._torso_idx - 1] * jnp.pi / 180.0, x.rot[self._torso_idx - 1]
)
reward_vel = -jnp.sum((vb[:2] - state.info["vel_tar"][:2]) ** 2)
reward_ang_vel = -jnp.sum((ab - state.info["ang_vel_tar"]) ** 2)
# height reward
reward_height = -jnp.sum(
(x.pos[self._torso_idx - 1, 2] - state.info["pos_tar"][2]) ** 2
)
# foot level reward
left_foot_mat = pipeline_state.site_xmat[self._left_foot_idx]
right_foot_mat = pipeline_state.site_xmat[self._right_foot_idx]
vec_left = left_foot_mat @ vec_tar
vec_right = right_foot_mat @ vec_tar
reward_foot_level = -jnp.sum((vec_left - vec_tar) ** 2 + (vec_right - vec_tar) ** 2)
# energy reward
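        # Mechanical-power proxy: normalized torque times joint velocity; the
        # 160.0 divisor appears to act as a joint-velocity normalizer
        # (assumption).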
        reward_energy = -jnp.sum(
            (
                ctrl
                / self.joint_torque_range[:, 1]
                * pipeline_state.qvel[6 : 6 + len(self.joint_range)]
                / 160.0
            )
            ** 2
        )
        # reward_energy = -jnp.sum((ctrl / self.joint_torque_range[:, 1]) ** 2)
# stay alive reward
reward_alive = 1.0 - state.done
# reward
reward = (
reward_gaits * 10.0
+ reward_air_time * 0.0
+ reward_pos * 0.0
+ reward_upright * 0.5
+ reward_yaw * 0.5
# + reward_pose * 0.0
+ reward_vel * 1.0
+ reward_ang_vel * 1.0
+ reward_height * 0.5
+ reward_foot_level * 0.02
+ reward_energy * 0.01
+ reward_alive * 0.0
)
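        # Note: the air-time, position, and alive terms are zero-weighted
        # above, i.e. disabled in this commit.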

        # done: terminate when the torso flips past horizontal, any joint
        # leaves its sampling range, or the torso height drops below 0.18 m
up = jnp.array([0.0, 0.0, 1.0])
joint_angles = pipeline_state.q[7:]
joint_angles = joint_angles[: len(self.joint_range)]
done = jnp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0
done |= jnp.any(joint_angles < self.joint_range[:, 0])
done |= jnp.any(joint_angles > self.joint_range[:, 1])
done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18
done = done.astype(jnp.float32)

# state management
state.info["step"] += 1
state.info["rng"] = rng
state.info["z_feet"] = z_feet
state.info["z_feet_tar"] = z_feet_tar
state.info["feet_air_time"] *= ~contact_filt_mm
state.info["last_contact"] = contact

state = state.replace(
pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
)
return state

def _get_obs(
self,
pipeline_state: base.State,
state_info: dict[str, Any],
) -> jax.Array:
x, xd = pipeline_state.x, pipeline_state.xd
vb = global_to_body_velocity(
xd.vel[self._torso_idx - 1], x.rot[self._torso_idx - 1]
)
ab = global_to_body_velocity(
xd.ang[self._torso_idx - 1] * jnp.pi / 180.0, x.rot[self._torso_idx - 1]
)
obs = jnp.concatenate(
[
state_info["vel_tar"],
state_info["ang_vel_tar"],
pipeline_state.ctrl,
pipeline_state.qpos,
vb,
ab,
pipeline_state.qvel[6:],
]
)
return obs

def render(
self,
trajectory: List[base.State],
camera: str | None = None,
width: int = 240,
height: int = 320,
) -> Sequence[np.ndarray]:
camera = camera or "track"
return super().render(trajectory, camera=camera, width=width, height=height)

def sample_command(self, rng: jax.Array) -> tuple[jax.Array, jax.Array]:
lin_vel_x = [-1.5, 1.5] # min max [m/s]
lin_vel_y = [-0.5, 0.5] # min max [m/s]
ang_vel_yaw = [-1.5, 1.5] # min max [rad/s]

_, key1, key2, key3 = jax.random.split(rng, 4)
lin_vel_x = jax.random.uniform(
key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1]
)
lin_vel_y = jax.random.uniform(
key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1]
)
ang_vel_yaw = jax.random.uniform(
key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1]
)
new_lin_vel_cmd = jnp.array([lin_vel_x[0], lin_vel_y[0], 0.0])
new_ang_vel_cmd = jnp.array([0.0, 0.0, ang_vel_yaw[0]])
return new_lin_vel_cmd, new_ang_vel_cmd

brax_envs.register_environment("unitree_h1_walk", UnitreeH1WalkEnv)
brax_envs.register_environment("unitree_h1_push_crate", UnitreeH1PushCrateEnv)
brax_envs.register_environment("unitree_h1_loco", UnitreeH1LocoEnv)
2 changes: 2 additions & 0 deletions dial_mpc/examples/__init__.py
@@ -1,6 +1,7 @@
examples = [
"unitree_h1_jog",
"unitree_h1_push_crate",
"unitree_h1_loco",
"unitree_go2_trot",
"unitree_go2_seq_jump",
"unitree_go2_crate_climb",
@@ -9,4 +10,5 @@
deploy_examples = [
"unitree_go2_trot_deploy",
"unitree_go2_seq_jump_deploy",
"unitree_h1_loco_deploy",
]
(7 more changed files not shown)
