Commit
updates
rogeriobonatti committed Jun 5, 2022
1 parent 198fda6 commit 4323728
Showing 4 changed files with 71 additions and 38 deletions.
19 changes: 19 additions & 0 deletions mushr_rhc_ros/launch/map_server_test_env.launch
@@ -0,0 +1,19 @@
<launch>
<!-- <arg name="map" default="$(find mushr_rhc_ros)/maps/real-floor0-edited.yaml" /> -->
<!-- <arg name="c2g_map" default="$(find mushr_rhc_ros)/maps/real-floor0-edited-c2g-full-loop.yaml" /> -->

<!-- <arg name="map" default="$(find mushr_rhc_ros)/maps/bravern_13_mod.yaml" />
<arg name="c2g_map" default="$(find mushr_rhc_ros)/maps/bravern_13_mod.yaml" /> -->

<arg name="map" default="$(find mushr_rhc_ros)/maps/map_12f_cropped.yaml" />
<arg name="c2g_map" default="$(find mushr_rhc_ros)/maps/map_12f_cropped.yaml" />

<node pkg="map_server" name="map_server" type="map_server" args="$(arg map)"/>
<node pkg="map_server" name="c2g_map" type="map_server" args="$(arg c2g_map)">
<remap from="/static_map" to="/c2g/static_map" />
<remap from="/map_metadata" to="/c2g/map_metadata" />
<remap from="/map" to="/c2g/map" />
</node>

<param name="map_file" value="$(arg c2g_map)" />
</launch>
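Note (not part of the diff): the remaps above expose the cost-to-go grid under the /c2g namespace. A minimal sketch of a consumer node, assuming the standard nav_msgs/OccupancyGrid message (node and callback names are illustrative):

#!/usr/bin/env python
# Illustrative only: read the cost-to-go grid republished on /c2g/map by the
# second map_server instance declared in map_server_test_env.launch.
import rospy
from nav_msgs.msg import OccupancyGrid

def c2g_map_cb(msg):
    # msg.info holds resolution and origin; msg.data is the row-major grid
    rospy.loginfo("c2g map: %dx%d cells at %.3f m/cell",
                  msg.info.width, msg.info.height, msg.info.resolution)

if __name__ == "__main__":
    rospy.init_node("c2g_map_listener")
    rospy.Subscriber("/c2g/map", OccupancyGrid, c2g_map_cb, queue_size=1)
    rospy.spin()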
44 changes: 25 additions & 19 deletions mushr_rhc_ros/launch/sim/sim_server_eval.launch
@@ -3,38 +3,44 @@
<arg name="map_server" default="1" />
<arg name="car_name" default="car" />

<!-- <arg name="out_path" default="/home/rb/hackathon_data/e2e_eval/model_test" />
<arg name="model_path" default="/home/rb/hackathon_data/aml_outputs/log_output/normal-kingfish/GPTiros_e2e_8gpu_2022-02-17_1645120431.7528405_2022-02-17_1645120431.7528613/model/epoch10.pth.tar" /> -->
<!-- logging output paths -->
<arg name="out_path" default="/home/rb/hackathon_data_premium/e2e_eval/model_test" />

<!-- current model! -->
<!-- <arg name="out_path" default="/home/rb/hackathon_data/e2e_eval/model_test" />
<arg name="model_path" default="/home/rb/hackathon_data/aml_outputs/log_output/deep-adder/GPTnips_e2e_2022-04-06_1649220643.529554_2022-04-06_1649220643.5295672/model/epoch21.pth.tar" /> -->
<!-- action model -->
<arg name="model_path_act" default="/home/rb/hackathon_data_premium/aml_outputs/log_output/hvd_test_16/GPTcorl_scratch_trainm_e2e_statet_pointnet_traini_1_nla_12_nhe_8_statel_0.01_2022-06-02_1654131996.2524076_2022-06-02_1654131996.2524228/model/epoch30.pth.tar" />

<!-- <arg name="out_path" default="/home/rb/hackathon_data/e2e_eval/model_test" />
<arg name="model_path" default="/home/rb/hackathon_data/aml_outputs/log_output/renewed-colt/GPTnips_8gpu_relu_e2e_slow_bias_2022-04-02_1648869033.3693378_2022-04-02_1648869033.36935/model/epoch22.pth.tar" /> -->
<!-- map model -->
<arg name="use_map" default="true" />
<arg name="model_path_map" default="/home/rb/hackathon_data_premium/aml_outputs/log_output/mapscratch_new_0/GPTcorl_map_trainm_map_sta_pointnet_traini_1_nla_12_nhe_8_2022-05-31_1653978768.732001_2022-05-31_1653978768.7320147/model/epoch28.pth.tar" />

<!-- <arg name="out_path" default="/home/rb/hackathon_data/e2e_eval/model_test" />
<arg name="model_path" default="/home/rb/hackathon_data/aml_outputs/log_output/mushr_nips/GPTnips_8gpu_relu_e2e_slow_2022-03-31_1648698385.6160257_2022-03-31_1648698385.6160388/model/epoch5.pth.tar" /> -->
<!-- localization model -->
<arg name="use_loc" default="false" />
<arg name="model_path_loc" default="/home/rb/hackathon_data_premium/aml_outputs/log_output/locscratch_new_0/GPTcorl_loc_trainm_loc_sta_pointnet_lr_6e-5_traini_1_nla_12_nhe_8_locx_0.01_locy_1_loca_10_locd_joint_2022-05-31_1653978601.5423563_2022-05-31_1653978601.5423756/model/epoch30.pth.tar" />

<!-- current model! -->
<!-- <arg name="out_path" default="/home/rb/hackathon_data/e2e_eval/model_test" />
<arg name="model_path" default="/home/rb/hackathon_data/aml_outputs/log_output/tough-mongoose/GPTnips_e2e_2022-04-07_1649355173.7379868_2022-04-07_1649355173.7379994/model/epoch22.pth.tar" /> -->
<arg name="deployment_map" default="train" />
<!-- <arg name="deployment_map" default="test" /> -->

<!-- current model! -->
<arg name="out_path" default="/home/rb/hackathon_data_premium/e2e_eval/model_test" />
<arg name="model_path" default="/home/rb/hackathon_data_premium/aml_outputs/log_output/pretrain_new_0/GPTcorl_scratch_trainm_e2e_statet_pointnet_traini_1_nla_12_nhe_8_statel_0.01_2022-05-31_1654007043.7606997_2022-05-31_1654007043.760718/model/epoch14.pth.tar" />
<!-- <arg name="model_path" default="/home/rb/hackathon_data_premium/aml_outputs/log_output/pretrain_scratch_0/GPTcorl_scratch_trainm_e2e_statet_pointnet_traini_1_nla_6_nhe_8_statel_0.01_2022-05-23_1653348711.3057182_2022-05-23_1653348711.3057299/model/epoch12.pth.tar" /> -->
<group if="$(eval arg('deployment_map') == 'train')">
<include file="$(find mushr_rhc_ros)/launch/map_server.launch" />
</group>

<group if="$(arg map_server)">
<include file="$(find mushr_rhc_ros)/launch/map_server.launch" />
<group if="$(eval arg('deployment_map') == 'test')">
<include file="$(find mushr_rhc_ros)/launch/map_server_test_env.launch" />
</group>

<group ns="$(arg car_name)">
<node pkg="mushr_rhc_ros" type="rhcnode_network_pcl_new.py" name="rhcontroller" output="screen">
<env name="RHC_USE_CUDA" value="0" />

<param name="deployment_map" value="$(arg deployment_map)" />

<param name="out_path" value="$(arg out_path)" />
<param name="model_path" value="$(arg model_path)" />
<param name="model_path_act" value="$(arg model_path_act)" />
<param name="model_path_map" value="$(arg model_path_map)" />
<param name="model_path_loc" value="$(arg model_path_loc)" />

<param name="use_map" value="$(arg use_map)" />
<param name="use_loc" value="$(arg use_loc)" />

<param name="inferred_pose_t" value="car_pose" />

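Note (not part of the diff): each <param> set inside the rhcontroller <node> block becomes a private ROS parameter, so the node reads it with a ~-prefixed rospy.get_param call, mirroring the reads added in rhcnode_network_pcl_new.py below. A minimal sketch, with defaults assumed to match the launch file:

import rospy

rospy.init_node("rhcontroller")
# private (~) parameters populated by the <param> tags in sim_server_eval.launch
model_path_act = rospy.get_param("~model_path_act", "default_value")
model_path_map = rospy.get_param("~model_path_map", "")
model_path_loc = rospy.get_param("~model_path_loc", "")
use_map = rospy.get_param("~use_map", False)
use_loc = rospy.get_param("~use_loc", False)
deployment_map = rospy.get_param("~deployment_map", "train")

Launching with, e.g., roslaunch mushr_rhc_ros sim_server_eval.launch deployment_map:=test switches the map include from map_server.launch to map_server_test_env.launch.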
6 changes: 3 additions & 3 deletions mushr_rhc_ros/src/rhcnode_network_pcl.py
@@ -152,7 +152,7 @@ def __init__(self, dtype, params, logger, name):
# mapping model
if self.use_map:

saved_map_model_path = '/home/rb/hackathon_data_premium/aml_outputs/log_output/mapscratch_new_0/GPTcorl_map_trainm_map_sta_pointnet_traini_1_nla_12_nhe_8_2022-05-31_1653978768.732001_2022-05-31_1653978768.7320147/model/epoch8.pth.tar'
saved_map_model_path = '/home/rb/hackathon_data_premium/aml_outputs/log_output/mapscratch_new_0/GPTcorl_map_trainm_map_sta_pointnet_traini_1_nla_12_nhe_8_2022-05-31_1653978768.732001_2022-05-31_1653978768.7320147/model/epoch28.pth.tar'

mconf_map = GPTConfig(block_size, max_timestep,
n_layer=12, n_head=8, n_embd=128, model_type='GPT', use_pred_state=True,
@@ -182,7 +182,7 @@ def __init__(self, dtype, params, logger, name):
# localization model
if self.use_loc:

saved_loc_model_path = '/home/rb/hackathon_data_premium/aml_outputs/log_output/locscratch_new_0/GPTcorl_loc_trainm_loc_sta_pointnet_lr_6e-5_traini_1_nla_12_nhe_8_locx_0.01_locy_1_loca_10_locd_joint_2022-05-31_1653978601.5423563_2022-05-31_1653978601.5423756/model/epoch8.pth.tar'
saved_loc_model_path = '/home/rb/hackathon_data_premium/aml_outputs/log_output/locscratch_new_0/GPTcorl_loc_trainm_loc_sta_pointnet_lr_6e-5_traini_1_nla_12_nhe_8_locx_0.01_locy_1_loca_10_locd_joint_2022-05-31_1653978601.5423563_2022-05-31_1653978601.5423756/model/epoch26.pth.tar'

mconf_loc = GPTConfig(block_size, max_timestep,
n_layer=12, n_head=8, n_embd=128, model_type='GPT', use_pred_state=True,
@@ -588,7 +588,7 @@ def prepare_model_inputs(self):

self.act_lock.acquire()
for act in self.q_actions.queue:
x_act[0,idx] = torch.tensor(act)
x_act[0,idx] = torch.tensor(pre.norm_angle(act))
idx+=1
self.act_lock.release()

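Note (not part of the diff): both changed lines wrap the raw action with pre.norm_angle before it is written into x_act. The helper's body is not shown in this commit; a hypothetical sketch, assuming it rescales a steering angle into [-1, 1] using a fixed steering limit (the actual pre.norm_angle in this repo may differ):

MAX_STEER = 0.34  # assumed steering limit in radians; illustrative only

def norm_angle(angle, max_steer=MAX_STEER):
    # clamp to the physical steering range, then rescale to [-1, 1]
    angle = max(-max_steer, min(max_steer, angle))
    return angle / max_steer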
40 changes: 24 additions & 16 deletions mushr_rhc_ros/src/rhcnode_network_pcl_new.py
@@ -84,8 +84,7 @@ def __init__(self, dtype, params, logger, name):
self.default_angle = 0.0
self.nx = None
self.ny = None
self.use_map = False
self.use_loc = True

self.points_viz_list = None
self.map_recon = None
self.loc_counter = 0
@@ -101,8 +100,12 @@ def __init__(self, dtype, params, logger, name):
self.device = device
self.clip_len = 16

# tests for IROS
saved_model_path = rospy.get_param("~model_path", 'default_value')
self.map_type = rospy.get_param("~deployment_map", 'train')

self.use_map = rospy.get_param("~use_map", False)
self.use_loc = rospy.get_param("~use_loc", False)

saved_model_path_action = rospy.get_param("~model_path_act", 'default_value')
self.out_path = rospy.get_param("~out_path", 'default_value')

vocab_size = 100
@@ -116,11 +119,11 @@
map_decoder='deconv', map_recon_dim=64, freeze_core=False,
state_loss_weight=0.1,
loc_x_loss_weight=0.01, loc_y_loss_weight=0.1, loc_angle_loss_weight=10.0,
loc_decoder_type='separate')
loc_decoder_type='joint')
model = GPT(mconf, device)
# model=torch.nn.DataParallel(model)

checkpoint = torch.load(saved_model_path, map_location=device)
checkpoint = torch.load(saved_model_path_action, map_location=device)
# old code for loading model
model.load_state_dict(checkpoint['state_dict'])
# new code for loading mode
@@ -153,8 +156,8 @@ def __init__(self, dtype, params, logger, name):

# mapping model
if self.use_map:

saved_map_model_path = '/home/rb/hackathon_data_premium/aml_outputs/log_output/mapscratch_new_0/GPTcorl_map_trainm_map_sta_pointnet_traini_1_nla_12_nhe_8_2022-05-31_1653978768.732001_2022-05-31_1653978768.7320147/model/epoch8.pth.tar'
saved_map_model_path = rospy.get_param("~model_path_map", '')

mconf_map = GPTConfig(block_size, max_timestep,
n_layer=12, n_head=8, n_embd=128, model_type='GPT', use_pred_state=True,
@@ -183,8 +186,8 @@ def __init__(self, dtype, params, logger, name):

# localization model
if self.use_loc:

saved_loc_model_path = '/home/rb/hackathon_data_premium/aml_outputs/log_output/locscratch_new_0/GPTcorl_loc_trainm_loc_sta_pointnet_lr_6e-5_traini_1_nla_12_nhe_8_locx_0.01_locy_1_loca_10_locd_joint_2022-05-31_1653978601.5423563_2022-05-31_1653978601.5423756/model/epoch17.pth.tar'
saved_loc_model_path = rospy.get_param("~model_path_loc", '')

mconf_loc = GPTConfig(block_size, max_timestep,
n_layer=12, n_head=8, n_embd=128, model_type='GPT', use_pred_state=True,
@@ -232,7 +235,7 @@ def __init__(self, dtype, params, logger, name):
# set timer callbacks for visualization
rate_map_display = 1.0
rate_loc_display = 20
# self.map_viz_timer = rospy.Timer(rospy.Duration(1.0 / rate_map_display), self.map_viz_cb)
self.map_viz_timer = rospy.Timer(rospy.Duration(1.0 / rate_map_display), self.map_viz_cb)
self.map_viz_loc = rospy.Timer(rospy.Duration(1.0 / rate_loc_display), self.loc_viz_cb)


@@ -249,8 +252,11 @@ def start(self):
self.logger.info("Initialized")

# set initial pose for the car in the very first time in an allowable region
self.send_initial_pose()
# self.send_initial_pose_12f()
if self.map_type == 'train':
self.send_initial_pose()
else:
self.send_initial_pose_12f()

self.time_started = rospy.Time.now()

# wait until we actually have a car pose
@@ -544,7 +550,7 @@ def prepare_model_inputs(self, queue_type):
self.large_queue_lock.release()
for idx, element in enumerate(queue_list):
x_imgs[0,idx,:] = torch.tensor(element[0])
x_act[0,idx] = torch.tensor(element[1])
x_act[0,idx] = torch.tensor(pre.norm_angle(element[1]))
# x_imgs = x_imgs.contiguous().view(1, self.clip_len, 200*200)
x_imgs = x_imgs.to(self.device)
x_act = x_act.view(1, self.clip_len , 1)
@@ -579,8 +585,10 @@ def check_reset(self, rate_hz):
print("Distance: {} | Time: {} | Time so far: {}".format(self.distance_so_far, delta_time, self.time_so_far))
with open(self.file_name,'a') as fd:
fd.write(str(self.distance_so_far)+','+str(self.time_so_far)+'\n')
self.send_initial_pose()
# self.send_initial_pose_12f()
if self.map_type == 'train':
self.send_initial_pose()
else:
self.send_initial_pose_12f()
rospy.loginfo("Got stuck, resetting pose of the car to default value")
msg = String()
msg.data = "got stuck"
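Note (not part of the diff): in rhcnode_network_pcl_new.py the map and localization checkpoint paths now arrive as ROS parameters instead of hard-coded strings, and each model is restored with torch.load followed by load_state_dict on the checkpoint's 'state_dict' entry. A minimal sketch of that pattern (the to(device) and eval() calls are assumptions; only the two calls named above appear in the diff):

import torch

def load_checkpoint(model, path, device):
    # map_location keeps CPU-only runs working (RHC_USE_CUDA=0 in the launch file)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)   # assumption: move weights onto the target device
    model.eval()       # assumption: the controller runs the model inference-only
    return model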
