Skip to content

Commit

Permalink
changes for real deploy
Browse files Browse the repository at this point in the history
  • Loading branch information
rogeriobonatti committed Jun 6, 2022
1 parent b4be28c commit f94d1b9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion mushr_rhc_ros/launch/sim/sim_server_eval.launch
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<!-- <arg name="model_path_act" default="/home/rb/hackathon_data_premium/aml_outputs/log_output/model_sizes_0/GPTcorl_scratch_trainm_e2e_statet_pointnet_traini_1_nla_24_nhe_8_statel_0.01_2022-06-03_1654235683.085602_2022-06-03_1654235683.0856125/model/epoch30.pth.tar" /> -->

<!-- map model -->
<arg name="use_map" default="true" />
<arg name="use_map" default="false" />
<!-- without fine-tuning, 100% of data -->
<!-- <arg name="model_path_map" default="/home/rb/hackathon_data_premium/aml_outputs/log_output/nofinetune_episodes_map_0/GPTcorl_map_trainm_map_sta_pointnet_traini_1_nla_12_nhe_8_2022-06-03_1654271262.2364998_2022-06-03_1654271262.2365131/model/epoch30.pth.tar" /> -->

Expand Down
33 changes: 23 additions & 10 deletions mushr_rhc_ros/src/rhcnode_network_pcl_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def __init__(self, dtype, params, logger, name):
self.device = device
self.clip_len = 16

self.is_real_deployment = rospy.get_param("~is_real_deployment", False)

self.map_type = rospy.get_param("~deployment_map", 'train')

self.use_map = rospy.get_param("~use_map", False)
Expand Down Expand Up @@ -186,7 +188,7 @@ def __init__(self, dtype, params, logger, name):
map_model.to(device)
self.map_model = map_model
rate_map_display = 1.0
self.map_viz_timer = rospy.Timer(rospy.Duration(1.0 / rate_map_display), self.map_viz_cb)


# localization model
if self.use_loc:
Expand Down Expand Up @@ -218,7 +220,7 @@ def __init__(self, dtype, params, logger, name):
loc_model.to(device)
self.loc_model = loc_model
rate_loc_display = 20
self.map_viz_loc = rospy.Timer(rospy.Duration(1.0 / rate_loc_display), self.loc_viz_cb)



self.small_queue = Queue(maxsize = self.clip_len) # stores current scan, action, pose. up to 16 elements
Expand All @@ -238,6 +240,12 @@ def __init__(self, dtype, params, logger, name):
self.time_so_far = 0.0
self.file_name = os.path.join(self.out_path,'info.csv')

# define timer callbacks:
if self.use_map:
self.map_viz_loc = rospy.Timer(rospy.Duration(1.0 / rate_loc_display), self.loc_viz_cb)
if self.use_loc:
self.map_viz_timer = rospy.Timer(rospy.Duration(1.0 / rate_map_display), self.map_viz_cb)




Expand Down Expand Up @@ -272,8 +280,9 @@ def start(self):
while not rospy.is_shutdown() and self.run:

# check if we should reset the vehicle if crashed
if self.check_reset(rate_hz):
rospy.loginfo("Resetting the car's position")
if not self.is_real_deployment:
if self.check_reset(rate_hz):
rospy.loginfo("Resetting the car's position")

# publish next action
if self.compute_network_action:
Expand All @@ -295,16 +304,20 @@ def map_viz_cb(self, timer):
pos_queue_list = list(self.small_queue.queue)
self.small_queue_lock.release()
pos_size = len(pos_queue_list)
pose_mid = pos_queue_list[int(pos_size/2) -1][2]


if pos_size==16:
pose_mid = pos_queue_list[int(pos_size/2) -1][2]
# if not self.is_real_deployment:
# pose_mid = pos_queue_list[int(pos_size/2) -1][2]
# else:
# pose_mid = PoseStamped()
x_imgs, x_act, t = self.prepare_model_inputs(queue_type='small')
start = time.time()
# with torch.set_grad_enabled(False):
with torch.inference_mode():
self.map_recon, _ = self.map_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False)
finished_map_network = time.time()
# rospy.loginfo("map network delay: "+str(finished_map_network-start))
rospy.loginfo("map network delay: "+str(finished_map_network-start))

# publish the GT pose of the map center
self.pose_marker_pub.publish(self.create_position_marker(pose_mid))
Expand Down Expand Up @@ -351,7 +364,7 @@ def loc_viz_cb(self, timer):
with torch.inference_mode():
pose_preds, _, _, _, _ = self.loc_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False)
finished_loc_network = time.time()
# rospy.loginfo("loc network delay: "+str(finished_loc_network-start))
rospy.loginfo("loc network delay: "+str(finished_loc_network-start))
# publish anchor pose of the map center
self.loc_anchor_pose_marker_pub.publish(self.create_position_marker(self.pose_anchor, color=[0,0,1,1]))
# publish the current accumulated pose
Expand Down Expand Up @@ -526,8 +539,8 @@ def apply_network(self):
# action_pred = 0.0
action_pred, _, _, _ = self.model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False)
finished_action_network = time.time()
# rospy.loginfo("action network delay: "+str(finished_action_network-start))
rospy.loginfo_throttle(10, "action network delay: "+str(finished_action_network-start))
rospy.loginfo("action network delay: "+str(finished_action_network-start))
# rospy.loginfo_throttle(10, "action network delay: "+str(finished_action_network-start))
action_pred = action_pred[0,-1,0].cpu().flatten().item()
# if self.use_map:
# map_pred, loss = self.map_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None)
Expand Down

0 comments on commit f94d1b9

Please sign in to comment.