From 043001b179893a83da34445e7fd5f960cd0d3b72 Mon Sep 17 00:00:00 2001 From: Super1ce <32703938+Super1ce@users.noreply.github.com> Date: Mon, 6 Nov 2023 12:51:10 +0800 Subject: [PATCH] polish(zc): change PD config name (#749) * add action * change entry --- ding/policy/plan_diffuser.py | 2 +- ding/utils/data/dataset.py | 2 ++ dizoo/d4rl/config/antmaze_umaze_pd_config.py | 2 +- dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py | 2 +- dizoo/d4rl/config/halfcheetah_medium_pd_config.py | 2 +- dizoo/d4rl/config/hopper_medium_expert_pd_config.py | 2 +- dizoo/d4rl/config/hopper_medium_pd_config.py | 2 +- dizoo/d4rl/config/walker2d_medium_expert_pd_config.py | 2 +- dizoo/d4rl/config/walker2d_medium_pd_config.py | 2 +- dizoo/d4rl/entry/d4rl_pd_main.py | 2 +- 10 files changed, 11 insertions(+), 9 deletions(-) diff --git a/ding/policy/plan_diffuser.py b/ding/policy/plan_diffuser.py index 7e6854789f..ad58546a15 100755 --- a/ding/policy/plan_diffuser.py +++ b/ding/policy/plan_diffuser.py @@ -178,7 +178,7 @@ def _init_learn(self) -> None: self.step_start_update_target = self._cfg.learn.step_start_update_target self.target_weight = self._cfg.learn.target_weight self.value_step = self._cfg.learn.value_step - self.use_target = True + self.use_target = False self.horizon = self._cfg.model.diffuser_model_cfg.horizon self.include_returns = self._cfg.learn.include_returns diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 013f8b688d..23db0fcdf9 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1090,11 +1090,13 @@ def __getitem__(self, idx, eps=1e-4): 'trajectories': trajectories, 'returns': returns, 'done': done, + 'action': actions, } else: batch = { 'trajectories': trajectories, 'done': done, + 'action': actions, } batch.update(self.get_conditions(observations)) diff --git a/dizoo/d4rl/config/antmaze_umaze_pd_config.py b/dizoo/d4rl/config/antmaze_umaze_pd_config.py index 8dadd63a13..96ca022545 100755 --- a/dizoo/d4rl/config/antmaze_umaze_pd_config.py +++ b/dizoo/d4rl/config/antmaze_umaze_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=37, dim=32, diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py index d3a0dbb4b8..66c8ba8d91 100755 --- a/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, diff --git a/dizoo/d4rl/config/halfcheetah_medium_pd_config.py b/dizoo/d4rl/config/halfcheetah_medium_pd_config.py index 2386c278ec..674395a4e1 100755 --- a/dizoo/d4rl/config/halfcheetah_medium_pd_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, diff --git a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py index 6205018751..3df47f8d1b 100755 --- a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=14, dim=32, diff --git a/dizoo/d4rl/config/hopper_medium_pd_config.py b/dizoo/d4rl/config/hopper_medium_pd_config.py index 49caaec5d2..8dfee5d824 100755 --- a/dizoo/d4rl/config/hopper_medium_pd_config.py +++ b/dizoo/d4rl/config/hopper_medium_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=14, dim=32, diff --git a/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py b/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py index 18cb45559b..3d4c060e83 100755 --- a/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, diff --git a/dizoo/d4rl/config/walker2d_medium_pd_config.py b/dizoo/d4rl/config/walker2d_medium_pd_config.py index 8b2c0b0a4a..29fce259c8 100755 --- a/dizoo/d4rl/config/walker2d_medium_pd_config.py +++ b/dizoo/d4rl/config/walker2d_medium_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, diff --git a/dizoo/d4rl/entry/d4rl_pd_main.py b/dizoo/d4rl/entry/d4rl_pd_main.py index 73e08288ed..1ca3c5b299 100755 --- a/dizoo/d4rl/entry/d4rl_pd_main.py +++ b/dizoo/d4rl/entry/d4rl_pd_main.py @@ -16,6 +16,6 @@ def train(args): parser = argparse.ArgumentParser() parser.add_argument('--seed', '-s', type=int, default=10) - parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py') + parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py') args = parser.parse_args() train(args) \ No newline at end of file