diff --git a/mushr_rhc_ros/launch/sim/sim_server_eval.launch b/mushr_rhc_ros/launch/sim/sim_server_eval.launch
index 71d7f8a..ca45e59 100644
--- a/mushr_rhc_ros/launch/sim/sim_server_eval.launch
+++ b/mushr_rhc_ros/launch/sim/sim_server_eval.launch
@@ -10,9 +10,13 @@
-
+
+
+
+
+
@@ -26,7 +30,7 @@
-
+
diff --git a/mushr_rhc_ros/src/rhcnode_network_pcl_new.py b/mushr_rhc_ros/src/rhcnode_network_pcl_new.py
index 5ba14fd..f9321d5 100755
--- a/mushr_rhc_ros/src/rhcnode_network_pcl_new.py
+++ b/mushr_rhc_ros/src/rhcnode_network_pcl_new.py
@@ -70,6 +70,9 @@ def __init__(self, dtype, params, logger, name):
self.time_started_goal = None
self.num_trials = 0
+ self.act_inference_time_sum = 0.0
+ self.act_inference_time_count = 0
+
self.cur_rollout = self.cur_rollout_ip = None
self.traj_pub_lock = threading.Lock()
@@ -107,7 +110,7 @@ def __init__(self, dtype, params, logger, name):
self.use_map = rospy.get_param("~use_map", False)
self.use_loc = rospy.get_param("~use_loc", False)
- saved_model_path_action = rospy.get_param("~model_path_act", 'default_value')
+ saved_model_path_action = rospy.get_param("~model_path_act", '')
self.out_path = rospy.get_param("~out_path", 'default_value')
self.n_layers = rospy.get_param("~n_layers", 12)
@@ -127,19 +130,20 @@ def __init__(self, dtype, params, logger, name):
model = GPT(mconf, device)
# model=torch.nn.DataParallel(model)
- checkpoint = torch.load(saved_model_path_action, map_location=device)
- # old code for loading model
- model.load_state_dict(checkpoint['state_dict'])
- # new code for loading mode
- # new_checkpoint = OrderedDict()
- # for key in checkpoint['state_dict'].keys():
- # new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key]
- # model.load_state_dict(new_checkpoint)
+ if len(saved_model_path_action)>0:
+ checkpoint = torch.load(saved_model_path_action, map_location=device)
+ # old code for loading model
+ model.load_state_dict(checkpoint['state_dict'])
+ # new code for loading mode
+ # new_checkpoint = OrderedDict()
+ # for key in checkpoint['state_dict'].keys():
+ # new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key]
+ # model.load_state_dict(new_checkpoint)
- # ckpt = torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict']
- # for key in ckpt:
- # print('********',key)
- # model.load_state_dict(torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'], strict=True)
+ # ckpt = torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict']
+ # for key in ckpt:
+ # print('********',key)
+ # model.load_state_dict(torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'], strict=True)
model.eval()
# model.half()
@@ -532,6 +536,7 @@ def publish_vel_marker(self):
def apply_network(self):
+ start_zero = time.time()
x_imgs, x_act, t = self.prepare_model_inputs(queue_type='small')
start = time.time()
# with torch.set_grad_enabled(False):
@@ -539,15 +544,18 @@ def apply_network(self):
# action_pred = 0.0
action_pred, _, _, _ = self.model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False)
finished_action_network = time.time()
- rospy.loginfo("action network delay: "+str(finished_action_network-start))
- # rospy.loginfo_throttle(10, "action network delay: "+str(finished_action_network-start))
+ # rospy.loginfo("action network delay: "+str(finished_action_network-start))
+ # self.act_inference_time_sum += finished_action_network-start
+ # self.act_inference_time_count += 1
+ rospy.loginfo_throttle(10, "action network delay: "+str(finished_action_network-start))
+ # rospy.loginfo_throttle(10, "AVG action network delay: "+str(self.act_inference_time_sum/self.act_inference_time_count))
action_pred = action_pred[0,-1,0].cpu().flatten().item()
# if self.use_map:
# map_pred, loss = self.map_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None)
# finished_map_network = time.time()
# rospy.loginfo("map network delay: "+str(finished_map_network-finished_action_network))
finished_network = time.time()
- # rospy.loginfo("network delay: "+str(finished_network-finish_processing))
+ # rospy.loginfo("network delay total: "+str(finished_network-start_zero))
# de-normalize
action_pred = pre.denorm_angle(action_pred)
return action_pred