Skip to content

Commit

Permalink
change
Browse files Browse the repository at this point in the history
  • Loading branch information
rogeriobonatti committed Feb 24, 2022
1 parent c926772 commit 4bd60d0
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions mushr_rhc_ros/src/rhcnode_network_shuang.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
import torch
from mingpt.model_resnetdirect import ResnetDirect, ResnetDirectWithActions
# from mingpt.model_musher import GPT, GPTConfig
# from mingpt.model_mushr_rogerio import GPT, GPTConfig
from mingpt.model_mushr_iros import GPT, GPTConfig
from mingpt.model_mushr_rogerio import GPT, GPTConfig
# from mingpt.model_mushr_iros import GPT, GPTConfig
import preprocessing_utils as pre
from visualization_msgs.msg import Marker

Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(self, dtype, params, logger, name):
self.clip_len = 16

# fine-tuned in real life
saved_model_path = '/home/robot/weight_files/epoch0.pth.tar'
# saved_model_path = '/home/robot/weight_files/epoch0.pth.tar'
# saved_model_path = '/home/rb/downloaded_models/epoch30.pth.tar'

# tests for IROS
Expand All @@ -91,7 +91,7 @@ def __init__(self, dtype, params, logger, name):
# saved_model_path = '/home/rb/downloaded_models/epoch30.pth.tar'

# saved_model_path = '/home/robot/weight_files/epoch15.pth.tar'
# saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/gpt_resnet18_0/GPTgpt_resnet18_4gpu_2022-01-24_1642987604.6403077_2022-01-24_1642987604.640322/model/epoch15.pth.tar'
saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/gpt_resnet18_0/GPTgpt_resnet18_4gpu_2022-01-24_1642987604.6403077_2022-01-24_1642987604.640322/model/epoch15.pth.tar'
# saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/gpt_resnet18_8_exp2/GPTgpt_resnet18_8gpu_exp2_2022-01-25_1643076745.003202_2022-01-25_1643076745.0032148/model/epoch12.pth.tar'
vocab_size = 100
block_size = self.clip_len * 2
Expand Down Expand Up @@ -282,7 +282,7 @@ def apply_network(self):
action_pred, loss = self.model(states=x_imgs, actions=x_act, targets=x_act, timesteps=t)
action_pred = action_pred[0,self.clip_len-1,0].cpu().flatten().item()
finished_network = time.time()
rospy.loginfo("network delay: "+str(finished_network-finish_processing))
# rospy.loginfo("network delay: "+str(finished_network-finish_processing))

# de-normalize
action_pred = pre.denorm_angle(action_pred)
Expand Down

0 comments on commit 4bd60d0

Please sign in to comment.