diff --git a/mushr_rhc_ros/launch/sim/sim_server_eval.launch b/mushr_rhc_ros/launch/sim/sim_server_eval.launch index 71d7f8a..ca45e59 100644 --- a/mushr_rhc_ros/launch/sim/sim_server_eval.launch +++ b/mushr_rhc_ros/launch/sim/sim_server_eval.launch @@ -10,9 +10,13 @@ - + + + + + @@ -26,7 +30,7 @@ - + diff --git a/mushr_rhc_ros/src/rhcnode_network_pcl_new.py b/mushr_rhc_ros/src/rhcnode_network_pcl_new.py index 5ba14fd..f9321d5 100755 --- a/mushr_rhc_ros/src/rhcnode_network_pcl_new.py +++ b/mushr_rhc_ros/src/rhcnode_network_pcl_new.py @@ -70,6 +70,9 @@ def __init__(self, dtype, params, logger, name): self.time_started_goal = None self.num_trials = 0 + self.act_inference_time_sum = 0.0 + self.act_inference_time_count = 0 + self.cur_rollout = self.cur_rollout_ip = None self.traj_pub_lock = threading.Lock() @@ -107,7 +110,7 @@ def __init__(self, dtype, params, logger, name): self.use_map = rospy.get_param("~use_map", False) self.use_loc = rospy.get_param("~use_loc", False) - saved_model_path_action = rospy.get_param("~model_path_act", 'default_value') + saved_model_path_action = rospy.get_param("~model_path_act", '') self.out_path = rospy.get_param("~out_path", 'default_value') self.n_layers = rospy.get_param("~n_layers", 12) @@ -127,19 +130,20 @@ def __init__(self, dtype, params, logger, name): model = GPT(mconf, device) # model=torch.nn.DataParallel(model) - checkpoint = torch.load(saved_model_path_action, map_location=device) - # old code for loading model - model.load_state_dict(checkpoint['state_dict']) - # new code for loading mode - # new_checkpoint = OrderedDict() - # for key in checkpoint['state_dict'].keys(): - # new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key] - # model.load_state_dict(new_checkpoint) + if len(saved_model_path_action)>0: + checkpoint = torch.load(saved_model_path_action, map_location=device) + # old code for loading model + model.load_state_dict(checkpoint['state_dict']) + # new code for loading mode + # new_checkpoint = OrderedDict() + # for key in checkpoint['state_dict'].keys(): + # new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key] + # model.load_state_dict(new_checkpoint) - # ckpt = torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'] - # for key in ckpt: - # print('********',key) - # model.load_state_dict(torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'], strict=True) + # ckpt = torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'] + # for key in ckpt: + # print('********',key) + # model.load_state_dict(torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'], strict=True) model.eval() # model.half() @@ -532,6 +536,7 @@ def publish_vel_marker(self): def apply_network(self): + start_zero = time.time() x_imgs, x_act, t = self.prepare_model_inputs(queue_type='small') start = time.time() # with torch.set_grad_enabled(False): @@ -539,15 +544,18 @@ def apply_network(self): # action_pred = 0.0 action_pred, _, _, _ = self.model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) finished_action_network = time.time() - rospy.loginfo("action network delay: "+str(finished_action_network-start)) - # rospy.loginfo_throttle(10, "action network delay: "+str(finished_action_network-start)) + # rospy.loginfo("action network delay: "+str(finished_action_network-start)) + # self.act_inference_time_sum += finished_action_network-start + # self.act_inference_time_count += 1 + rospy.loginfo_throttle(10, "action network delay: "+str(finished_action_network-start)) + # rospy.loginfo_throttle(10, "AVG action network delay: "+str(self.act_inference_time_sum/self.act_inference_time_count)) action_pred = action_pred[0,-1,0].cpu().flatten().item() # if self.use_map: # map_pred, loss = self.map_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None) # finished_map_network = time.time() # rospy.loginfo("map network delay: "+str(finished_map_network-finished_action_network)) finished_network = time.time() - # rospy.loginfo("network delay: "+str(finished_network-finish_processing)) + # rospy.loginfo("network delay total: "+str(finished_network-start_zero)) # de-normalize action_pred = pre.denorm_angle(action_pred) return action_pred