Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rogeriobonatti committed Feb 24, 2022
1 parent 865091f commit 3f5b767
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions mushr_rhc_ros/src/rhcnode_network_shuang.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ def __init__(self, dtype, params, logger, name):
device = torch.device('cuda')

self.clip_len = 16

# fine-tuned in real life
saved_model_path = '/home/robot/weight_files/epoch0.pth.tar'
# saved_model_path = '/home/rb/downloaded_models/epoch30.pth.tar'
saved_model_path = '/home/robot/weight_files/epoch15.pth.tar'

# best so far
# saved_model_path = '/home/robot/weight_files/epoch15.pth.tar'
# saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/gpt_resnet18_0/GPTgpt_resnet18_4gpu_2022-01-24_1642987604.6403077_2022-01-24_1642987604.640322/model/epoch15.pth.tar'
# saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/gpt_resnet18_8_exp2/GPTgpt_resnet18_8gpu_exp2_2022-01-25_1643076745.003202_2022-01-25_1643076745.0032148/model/epoch12.pth.tar'
vocab_size = 100
Expand Down Expand Up @@ -161,7 +166,7 @@ def start(self):
self.last_action = self.apply_network()
self.q_actions.get() # remove the oldest action from the queue
self.q_actions.put(self.last_action)
rospy.loginfo("Applied network: "+str(self.last_action))
# rospy.loginfo("Applied network: "+str(self.last_action))
self.compute_network = False

self.publish_traj(self.default_speed, self.last_action)
Expand Down Expand Up @@ -218,7 +223,7 @@ def apply_network(self):
t = t.to(self.device)

finish_processing = time.time()
rospy.loginfo("processing delay: "+str(finish_processing-start))
# rospy.loginfo("processing delay: "+str(finish_processing-start))

# organize the action input
with torch.set_grad_enabled(False):
Expand Down

0 comments on commit 3f5b767

Please sign in to comment.