diff --git a/mj_envs/hand_manipulation_suite/door_v0.py b/mj_envs/hand_manipulation_suite/door_v0.py
index 061707cb..2eabb0d5 100644
--- a/mj_envs/hand_manipulation_suite/door_v0.py
+++ b/mj_envs/hand_manipulation_suite/door_v0.py
@@ -57,6 +57,19 @@ def _step(self, a):
 
         return ob, reward, False, {}
 
+    def get_proprioception(self, use_tactile):
+        """Return flat [qpos[:-2], qvel[:-2], grasp-site xpos, sensordata[:41] clipped to +/-5 if use_tactile]."""
+        robot_jnt = self.data.qpos.ravel()[:-2]
+        robot_vel = self.data.qvel.ravel()[:-2]
+        palm_pos = self.data.site_xpos[self.grasp_sid].ravel()
+        sensordata = []
+        if use_tactile:
+            sensordata = self.data.sensordata.ravel().copy()[:41]
+            sensordata = np.clip(sensordata, -5.0, 5.0)
+
+        res = np.concatenate([robot_jnt, robot_vel, palm_pos, sensordata])
+        return res
+
     def _get_obs(self):
         # qpos for hand
         # xpos for obj
diff --git a/mj_envs/hand_manipulation_suite/hammer_v0.py b/mj_envs/hand_manipulation_suite/hammer_v0.py
index 42ba06f0..3ff72a3c 100644
--- a/mj_envs/hand_manipulation_suite/hammer_v0.py
+++ b/mj_envs/hand_manipulation_suite/hammer_v0.py
@@ -78,6 +78,19 @@ def _get_obs(self):
         nail_impact = np.clip(self.sim.data.sensordata[self.sim.model.sensor_name2id('S_nail')], -1.0, 1.0)
         return np.concatenate([qp[:-6], qv[-6:], palm_pos, obj_pos, obj_rot, target_pos, np.array([nail_impact])])
 
+    def get_proprioception(self, use_tactile):
+        """Return flat [qpos[:-6], qvel[:-6], S_grasp-site xpos, sensordata[:41] clipped to +/-5 if use_tactile]."""
+        robot_jnt = self.data.qpos.ravel()[:-6]
+        robot_vel = self.data.qvel.ravel()[:-6]
+        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
+        sensordata = []
+        if use_tactile:
+            sensordata = self.data.sensordata.ravel().copy()[:41]
+            sensordata = np.clip(sensordata, -5.0, 5.0)
+
+        res = np.concatenate([robot_jnt, robot_vel, palm_pos, sensordata])
+        return res
+
     def reset_model(self):
         self.sim.reset()
         target_bid = self.model.body_name2id('nail_board')
diff --git a/mj_envs/hand_manipulation_suite/pen_v0.py b/mj_envs/hand_manipulation_suite/pen_v0.py
index 8cceba60..f7480884 100644
--- a/mj_envs/hand_manipulation_suite/pen_v0.py
+++ b/mj_envs/hand_manipulation_suite/pen_v0.py
@@ -85,6 +85,18 @@ def _get_obs(self):
         return np.concatenate([qp[:-6], obj_pos, obj_vel, obj_orien, desired_orien,
                                obj_pos-desired_pos, obj_orien-desired_orien])
 
+    def get_proprioception(self, use_tactile):
+        """Return flat [qpos[:-6], qvel[:-6], S_grasp-site xpos, unclipped sensordata[20:41] if use_tactile]."""
+        robot_jnt = self.data.qpos.ravel()[:-6]
+        robot_vel = self.data.qvel.ravel()[:-6]
+        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
+        sensordata = []
+        if use_tactile:
+            sensordata = self.data.sensordata.ravel().copy()[20:41]
+
+        res = np.concatenate([robot_jnt, robot_vel, palm_pos, sensordata])
+        return res
+
     def reset_model(self):
         qp = self.init_qpos.copy()
         qv = self.init_qvel.copy()
diff --git a/mj_envs/hand_manipulation_suite/relocate_v0.py b/mj_envs/hand_manipulation_suite/relocate_v0.py
index f1c351c5..6381ec14 100644
--- a/mj_envs/hand_manipulation_suite/relocate_v0.py
+++ b/mj_envs/hand_manipulation_suite/relocate_v0.py
@@ -58,7 +58,17 @@ def _get_obs(self):
         palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
         target_pos = self.data.site_xpos[self.target_obj_sid].ravel()
         return np.concatenate([qp[:-6], palm_pos-obj_pos, palm_pos-target_pos, obj_pos-target_pos])
-    
+
+    def get_proprioception(self, use_tactile):
+        """Return flat [qpos[:-6], S_grasp-site xpos, unclipped sensordata[20:41] if use_tactile]; no qvel here."""
+        robot_pos = self.data.qpos.ravel()[:-6]
+        palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel()
+        sensordata = []
+        if use_tactile:
+            sensordata = self.data.sensordata.ravel().copy()[20:41]
+        res = np.concatenate([robot_pos, palm_pos, sensordata])
+        return res
+
     def reset_model(self):
         qp = self.init_qpos.copy()
         qv = self.init_qvel.copy()
diff --git a/mj_envs/mujoco_env.py b/mj_envs/mujoco_env.py
index 9425e802..8290ad3c 100644
--- a/mj_envs/mujoco_env.py
+++ b/mj_envs/mujoco_env.py
@@ -81,8 +81,13 @@ def viewer_setup(self):
         """
         pass
 
-    # -----------------------------
+    def get_proprioception(self, use_tactile):
+        """
+        Subclass hook for the VIL paper's proprioceptive observation; this base version does nothing and returns None.
+        """
+        pass
 
+    # -----------------------------
     def _reset(self):
         self.sim.reset()
         self.sim.forward()
@@ -129,6 +134,12 @@ def state_vector(self):
             state.qpos.flat, state.qvel.flat])
 
     # -----------------------------
+    def get_pixels(self, frame_size=(128, 128), camera_name=None, device_id=0):
+        pixels = self.sim.render(width=frame_size[0], height=frame_size[1],
+                                 mode='offscreen', camera_name=camera_name, device_id=device_id)
+        # Reverse the first axis of the frame (presumably image rows, i.e. a vertical flip; confirm against renderer).
+        pixels = pixels[::-1, :, :]
+        return pixels
 
     def visualize_policy(self, policy, horizon=1000, num_episodes=1, mode='exploration'):
         self.mujoco_render_frames = True