From c25b8b8c313a0b58558459f37a5e011619e62393 Mon Sep 17 00:00:00 2001 From: Rogerio Bonatti Date: Tue, 24 May 2022 19:55:32 -0700 Subject: [PATCH] included new pointnet models deployment --- .../launch/sim/sim_server_eval.launch | 5 +- mushr_rhc_ros/src/mingpt/model_mushr_new2.py | 618 ++++++++++++ mushr_rhc_ros/src/rhcnode_network_corl.py | 900 ++++++++++++++++++ mushr_rhc_ros/src/rhcnode_network_pcl.py | 105 +- 4 files changed, 1575 insertions(+), 53 deletions(-) create mode 100644 mushr_rhc_ros/src/mingpt/model_mushr_new2.py create mode 100755 mushr_rhc_ros/src/rhcnode_network_corl.py diff --git a/mushr_rhc_ros/launch/sim/sim_server_eval.launch b/mushr_rhc_ros/launch/sim/sim_server_eval.launch index 43e2107..5c6bc92 100644 --- a/mushr_rhc_ros/launch/sim/sim_server_eval.launch +++ b/mushr_rhc_ros/launch/sim/sim_server_eval.launch @@ -21,8 +21,9 @@ --> - - + + + diff --git a/mushr_rhc_ros/src/mingpt/model_mushr_new2.py b/mushr_rhc_ros/src/mingpt/model_mushr_new2.py new file mode 100644 index 0000000..94d44f2 --- /dev/null +++ b/mushr_rhc_ros/src/mingpt/model_mushr_new2.py @@ -0,0 +1,618 @@ +""" +GPT model: +- the initial stem consists of a combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of Transformer blocks + - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway similar to resnets +- the final decoder is a linear projection into a vanilla Softmax classifier +""" + +from base64 import encode +import math +import logging +from collections import OrderedDict +import time + +import torch +import torch.nn as nn +from torch.nn import functional as F +from resnet_custom import resnet18_custom, resnet50_custom +from mingpt.pointnet import PCL_encoder as PointNet + +logger = logging.getLogger(__name__) + +import numpy as np + +import sys +sys.path.append('../') +from models.compass.select_backbone import select_resnet + +class GELU(nn.Module): + def forward(self, input): + return F.gelu(input) + +class GPTConfig: + """ base GPT config, params common to all GPT versions """ + embd_pdrop = 0.1 + resid_pdrop = 0.1 + attn_pdrop = 0.1 + + def __init__(self, block_size, max_timestep, **kwargs): + self.block_size = block_size + self.max_timestep = max_timestep + for k,v in kwargs.items(): + setattr(self, k, v) + print(k, v) + +class GPT1Config(GPTConfig): + """ GPT-1 like network roughly 125M params """ + n_layer = 12 + n_head = 12 + n_embd = 768 + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. 
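+
+    Shape sketch for one forward pass (a restatement of the comments in forward(),
+    with B = batch, T = sequence length, C = n_embd, nh = n_head, hs = C // nh):
+        q, k, v                                  : (B, nh, T, hs)
+        att = softmax(mask(q @ k^T / sqrt(hs)))  : (B, nh, T, T)
+        y = att @ v, heads re-assembled          : (B, T, C)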
+ """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) + # .view(1, 1, config.block_size, config.block_size)) + self.register_buffer("mask", torch.tril(torch.ones(config.block_size + 1, config.block_size + 1)) + .view(1, 1, config.block_size + 1, config.block_size + 1)) + self.n_head = config.n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + GELU(), + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class GPT(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, config, device): + super().__init__() + + self.config = config + self.device = device + + self.model_type = config.model_type + self.use_pred_state = config.use_pred_state + + self.map_recon_dim = config.map_recon_dim + self.freeze_core = config.freeze_core + + # input embedding stem + # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd)) + # self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep+1, config.n_embd)) + + self.drop = nn.Dropout(config.embd_pdrop) + + self.embed_timestep_global = nn.Embedding(config.max_timestep*2, config.n_embd) # unique for each token + self.embed_timestep_local = nn.Embedding(config.max_timestep, config.n_embd) # is the same for S_t and a_t + + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + + self.block_size = config.block_size + 
self.apply(self._init_weights) + + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + # DEFINE TOKENIZERS / ENCODERS + + # for the state + if config.state_tokenizer == 'conv2D': + self.state_encoder = nn.Sequential(nn.Conv2d(1, 32, 8, stride=4, padding=0), nn.ReLU(), + nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), + nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), + nn.Flatten(), nn.Linear(36864, config.n_embd), nn.ReLU()) + elif config.state_tokenizer == 'resnet18': + self.state_encoder = nn.Sequential(resnet18_custom(pretrained=False, clip_len=1), nn.ReLU(), + nn.Linear(1000, config.n_embd), nn.Tanh()) + elif config.state_tokenizer == 'pointnet': + self.state_encoder = PointNet(config.n_embd) + + # for the action + # simple action embedding as single linear layer + self.action_embeddings = nn.Sequential( + nn.Linear(1, config.n_embd) + ) + # self.action_embeddings = nn.Sequential( + # nn.Linear(1, 32), + # nn.ReLU(), + # nn.Linear(32, 64), + # nn.ReLU(), + # nn.Linear(64, config.n_embd) + # ) + + # DEFINE DECODERS + + # add map decoder + # from almost all the tokens of s_0:t and a_0:t-1, predict the overall map centered in the middle pose of the vehicle + # a_t is not taken into account because it's still being predicted in deployment and does not affect the map + encoded_feat_dim = config.n_embd * (config.block_size - 1) + if config.map_decoder == 'mlp': + #MLP map decoder + self.map_decoder = nn.Sequential(nn.Linear(encoded_feat_dim, 1024), + nn.ReLU(), + nn.Linear(1024, 2048), + nn.ReLU(), + nn.Linear(2048, 64*64), + nn.Tanh()) + elif config.map_decoder == 'deconv': + if self.map_recon_dim == 64: + # conv2d map decoder - original + self.map_decoder = nn.Sequential(nn.Linear(encoded_feat_dim, 4096), + nn.ReLU(), + Reshape(16, 16, 16), + MapDecoder_2x_Deconv(16)) + elif self.map_recon_dim == 128: + # conv2d map decoder - new trial + self.map_decoder = nn.Sequential(nn.Linear(encoded_feat_dim, 4096), + nn.ReLU(), + Reshape(16, 16, 16), + MapDecoder_4x_Deconv128px(16)) + else: + print('Not support!') + else: + print('Not support!') + + # from the token of s_t, predict next action a_t (like a policy) + self.predict_action = nn.Sequential( + nn.Linear(config.n_embd, 1) + ) + + # from the tokens of s_t and a_t, predict embedding of next state s_t+1 (like a dynamics model) + # we don't predict the next state directly because it's potentially an image and that would require large capacity + self.predict_state_token = nn.Sequential( + nn.Linear(2*config.n_embd, config.n_embd) + ) + + # from the tokens of s_t, s_t-1 a_t-1, predict delta in pose from t-1 to t + self.predict_pose_linear = nn.Sequential( + nn.Linear(3*config.n_embd, 3) + ) + + # from the tokens of s_t, s_t-1 a_t-1, predict delta in pose from t-1 to t + self.predict_pose_deep = nn.Sequential( + nn.Linear(3*config.n_embd, 128), + nn.ReLU(), + nn.Linear(128, 64), + nn.ReLU(), + nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 3) + ) + + criterion = torch.nn.MSELoss(reduction='mean') + self.criterion = criterion.cuda(device) + + self.load_pretrained_model_weights(config.pretrained_model_path) + + def load_pretrained_model_weights(self, model_path): + if model_path: + checkpoint = torch.load(model_path, map_location=self.device) + + # remove the 'module.' 
string from dict if the saved model was dataparallel + new_checkpoint = OrderedDict() + for key in checkpoint['state_dict'].keys(): + if key.startswith('module.'): + new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key] + else: + new_checkpoint[key] = checkpoint['state_dict'][key] + + # find the common keys + + # separate the components from the main transformer, tokenizer, and decoders + + self.load_state_dict(new_checkpoint, strict=False) + print('Successfully loaded pretrained checkpoint: {}.'.format(model_path)) + + # for key in ckpt: + # print(key) + + # for param in self.parameters(): + # print(param) + + # ckpt = torch.load(model_path)['state_dict'] # COMPASS checkpoint format. + # ckpt2 = {} + # ckpt3 = {} + # ckpt4 = {} + # for key in ckpt: + # print(key) + # if key.startswith('blocks'): + # ckpt2[key.replace('blocks.', '')] = ckpt[key] + # if key.startswith('state_encoder'): + # ckpt3[key.replace('state_encoder.', '')] = ckpt[key] + # if key.startswith('action_embeddings'): + # ckpt4[key.replace('action_embeddings.', '')] = ckpt[key] + + # self.blocks.load_state_dict(ckpt) + # self.state_encoder.load_state_dict(ckpt3) + # self.action_embeddings.load_state_dict(ckpt4) + # print('Successfully loaded pretrained checkpoint: {}.'.format(model_path)) + else: + print('Train from scratch.') + + + def reconstruction_loss(self, pred, target): + loss = F.l1_loss(pred, target) + return loss + + def _initialize_weights(self, module): + for name, param in module.named_parameters(): + if 'bias' in name: + nn.init.constant_(param, 0.0) + elif 'weight' in name: + nn.init.orthogonal_(param, 0.1) + + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def configure_optimizers_frozen(self, train_config): + # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=train_config.learning_rate, betas=train_config.betas) + non_frozen_params_list = list(self.map_decoder.parameters()) + \ + list(self.predict_action.parameters()) + \ + list(self.predict_pose_linear.parameters()) + \ + list(self.predict_pose_deep.parameters()) + \ + list(self.predict_state_token.parameters()) + optimizer = torch.optim.Adam(non_frozen_params_list, lr=train_config.learning_rate, betas=train_config.betas) + # for p in self.parameters(): + # if p.requires_grad: + # print(p) + return optimizer + + def configure_optimizers(self, train_config): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. 
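+
+        A minimal usage sketch (assuming a train_config object that exposes
+        learning_rate, betas and weight_decay, as consumed below):
+
+            optimizer = model.configure_optimizers(train_config)  # AdamW with decay / no-decay groups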
+ """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + # whitelist_weight_modules = (torch.nn.Linear, ) + whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, torch.nn.BatchNorm3d, torch.nn.BatchNorm2d, torch.nn.BatchNorm1d) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('embed_timestep_global.weight') + no_decay.add('embed_timestep_local.weight') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + return optimizer + + # state, and action + def forward(self, states, actions, targets=None, gt_map=None, timesteps=None, poses=None, compute_loss=True): + # states: (batch, block_size, 4*84*84) + # actions: (batch, block_size, 1) + # targets: (batch, block_size, 1) + # timesteps: (batch, 1, 1) + if self.config.state_tokenizer == 'resnet18': + state_embeddings = self.state_encoder(states.reshape(-1, 1, 200 , 200).type(torch.float32).contiguous()) # (batch * block_size, n_embd) + state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd) + elif self.config.state_tokenizer == 'conv2D': + state_embeddings = self.state_encoder(states.reshape(-1, 1, 244 , 244).type(torch.float32).contiguous()) # (batch * block_size, n_embd) + state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd) + elif self.config.state_tokenizer == 'pointnet': + state_embeddings = self.state_encoder(states.reshape(-1, 2 , 720).type(torch.float32).contiguous()) # (batch * block_size, n_embd) + state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd) + else: + print('Not supported!') + + if actions is not None and self.model_type == 'GPT': + + B, N, C = actions.shape + action_embeddings = self.action_embeddings(actions.view(B*N, C)).view(B, N, -1) # (batch, block_size, n_embd) + + #token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2 - int(targets is None), 
self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) + token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=self.device) + + token_embeddings[:,::2,:] = state_embeddings + token_embeddings[:,1::2,:] = action_embeddings + else: + raise NotImplementedError() + + # batch_size = states.shape[0] + # all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd + # position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] + #position_embeddings = torch.gather(all_global_pos_emb, 1, self.pos_emb[:, :token_embeddings.shape[1], :].type(torch.long)) + + # original code + # t = np.ones((B, 1, 1), dtype=int) * 7 + # t = torch.tensor(t) + # timesteps = t + # timesteps = timesteps.to(self.device) + # position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps.to(self.device), self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] + + # lines used for debugging + # t_local = 0,1,2,...,15 + # t_global = 0,1,2,...,31 -> up until 2*t_max + # timesteps = torch.arange(0, N).view(1,-1) + # self.embed_timestep_global = self.embed_timestep_global.cpu() + # self.embed_timestep_local = self.embed_timestep_local.cpu() + + # calculate global positional embedding (unique for each token) + timesteps_global = torch.arange(timesteps.min(), timesteps.min()+2*timesteps.shape[1]).view(1,-1).to(self.device) + position_embeddings_global = self.embed_timestep_global(timesteps_global) + position_embeddings_global = position_embeddings_global.repeat(B,1,1) + + # calculate local positional embedding (unique for each state-action pair) + position_embeddings_local = self.embed_timestep_local(timesteps) + position_embeddings_local = torch.repeat_interleave(position_embeddings_local, 2, dim=1) + # position_embeddings_local = position_embeddings_local.repeat(B,1,1) + + x = self.drop(token_embeddings + position_embeddings_global + position_embeddings_local) + x = self.blocks(x) + x = self.ln_f(x) + + if self.config.train_mode == 'e2e': + # from the tokens of s_t, predict next action a_t (like a policy) + action_preds = self.predict_action(x[:, ::2, :]) + # from the tokens of s_t and a_t, predict embedding of next state s_t+1 + B, N, D = x.shape + # -1 to ignore the last s_t, a_t pair because we will not predict s_t+1 at the very end + state_feat = x.reshape(B, int(N/2), -1)[:,:-1,:] + state_preds = self.predict_state_token(state_feat) + elif self.config.train_mode == 'map': + percep_feat = x[:, :-1, :] + B, N, D = percep_feat.shape + feat = percep_feat.reshape(B, -1) # reshape to a vector + # from all the tokens of s_0:t and a_0:t, predict the overall map centered in the middle pose of the vehicle + map_recon = self.map_decoder(feat) + elif self.config.train_mode == 'loc': + # from the tokens of s_t, s_t-1 a_t-1, predict delta in pose from t-1 to t + B, N, D = x.shape + # will not use the last tokens, + pose_feat = torch.zeros(B, 3*int(N/2-1), D, device=self.device) + pose_feat[:,::3,:] = x[:, ::2, :][:,:-1,:] + pose_feat[:,1::3,:] = x[:, 1::2, :][:,:-1,:] + pose_feat[:,2::3,:] = x[:, 2::2, :] + pose_feat = pose_feat.reshape(B, int(N/2)-1, -1) + pose_preds = self.predict_pose_deep(pose_feat) + # pose_preds[:,:,2:] = torch.tanh(pose_preds[:,:,2:]) + elif self.config.train_mode == 'joint': + action_preds = 
self.predict_action(x[:, ::2, :]) + percep_feat = x[:, ::2, :] + B, N, D = percep_feat.shape + feat = percep_feat.reshape(B, -1) # reshape to a vector + map_recon = self.map_decoder(feat) + pose_preds = self.predict_pose_deep(x[:, ::2, :]) + # pose_preds[:,:,2:] = torch.tanh(pose_preds[:,:,2:]) + else: + print('Not supported!') + + + loss = None + loss_act = None + loss_states = None + loss_translation_x = None + loss_translation_y = None + loss_angle = None + if targets is not None: + if self.config.train_mode == 'map': + if compute_loss: + loss = self.criterion(map_recon.reshape(-1, self.map_recon_dim, self.map_recon_dim), gt_map) + return map_recon, loss + elif self.config.train_mode == 'e2e': + # loss over N timesteps + if compute_loss: + loss_act = self.criterion(actions, action_preds) + # for states, try to predict the next state based on current state and action + loss_states = self.criterion(x[:, 2::2, :], state_preds) + if self.config.use_pred_state: + w_state = self.config.state_loss_weight + else: + w_state = 0.0 + loss = loss_act + w_state*loss_states + return action_preds, loss, loss_act, loss_states + elif self.config.train_mode == 'loc': + # loss over N timesteps + if compute_loss: + # the :-1 ignores the pose prediction loss computed at the last element + # because the last state after sequence was garbage (zero) + loss_translation_x = self.criterion(poses[:,:,0],pose_preds[:,:,0]) + loss_translation_y = self.criterion(poses[:,:,1],pose_preds[:,:,1]) + loss_angle = self.criterion(poses[:,:,2],pose_preds[:,:,2]) + # scale angle loss, similar to DeepVO paper (they use factor of 100, but car is moving faster) + loss = self.config.loc_x_loss_weight*loss_translation_x + \ + self.config.loc_y_loss_weight*loss_translation_y + \ + self.config.loc_angle_loss_weight*loss_angle + return pose_preds, loss, loss_translation_x, loss_translation_y, loss_angle + elif self.config.train_mode == 'joint': + return action_preds, map_recon, pose_preds + +class MapDecoder_4x_Deconv(nn.Module): + def __init__(self, in_channels=384): + super().__init__() + + self.decoder = nn.Sequential( + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(in_channels, 256, kernel_size=3, stride=2, padding=1), output_size=(50, 50)), + nn.BatchNorm2d(256), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1), output_size=(100, 100)), + nn.BatchNorm2d(128), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1), output_size=(200, 200)), + nn.BatchNorm2d(64), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1), output_size=(400, 400)), + ) + + def forward(self, x): + return self.decoder(x) + +class MapDecoder_2x_Deconv(nn.Module): + def __init__(self, in_channels=768): + super().__init__() + + # The parameters for ConvTranspose2D are from the PyTorch repo. 
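+        # (For reference: H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1,
+        #  hence the ConvTranspose2d_FixOutputSize wrapper defined below, which pins output_size explicitly.)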
+ # Ref: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + # Ref: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + # Ref: https://discuss.pytorch.org/t/the-output-size-of-convtranspose2d-differs-from-the-expected-output-size/1876/13 + # Ref: (padding) https://towardsdatascience.com/what-is-transposed-convolutional-layer-40e5e6e31c11 + self.decoder = nn.Sequential( + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(in_channels, 8, kernel_size=3, stride=2, padding=1), output_size=(32, 32)), + nn.BatchNorm2d(8), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1), output_size=(64, 64)), + nn.Tanh() + ) + + def forward(self, x): + return self.decoder(x) + +class MapDecoder_4x_Deconv128px(nn.Module): + def __init__(self, in_channels=384): + super().__init__() + + self.decoder = nn.Sequential( + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(in_channels, 256, kernel_size=3, stride=2, padding=1), output_size=(32, 32)), + nn.BatchNorm2d(256), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1), output_size=(64, 64)), + nn.BatchNorm2d(128), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(128, 1, kernel_size=3, stride=2, padding=1), output_size=(128, 128)), + nn.Tanh() + ) + + def forward(self, x): + return self.decoder(x) + + +class MapDecoder_5x_Deconv_64output(nn.Module): + def __init__(self, in_channels=768): + super().__init__() + self.decoder = nn.Sequential( + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(in_channels, 64, kernel_size=4, stride=1, padding=2), output_size=(8, 8)), + nn.BatchNorm2d(8), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding=2), output_size=(16, 16)), + nn.BatchNorm2d(8), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(32, 16, kernel_size=6, stride=2, padding=2), output_size=(32, 32)), + nn.BatchNorm2d(8), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1), output_size=(32, 32)), + nn.BatchNorm2d(8), + nn.ReLU(), + ConvTranspose2d_FixOutputSize(nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1), output_size=(64, 64)), + nn.Tanh() + ) + + def forward(self, x): + return self.decoder(x) + + +class ConvTranspose2d_FixOutputSize(nn.Module): + """ + A wrapper to fix the output size of ConvTranspose2D. + Ref: https://discuss.pytorch.org/t/the-output-size-of-convtranspose2d-differs-from-the-expected-output-size/1876/13 + Ref: (other alternatives) https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + def __init__(self, conv, output_size): + super(ConvTranspose2d_FixOutputSize, self).__init__() + self.output_size = output_size + self.conv = conv + + def forward(self, x): + x = self.conv(x, output_size=self.output_size) + return x + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + def forward(self, x): + return x.view((x.size(0),)+self.shape) \ No newline at end of file diff --git a/mushr_rhc_ros/src/rhcnode_network_corl.py b/mushr_rhc_ros/src/rhcnode_network_corl.py new file mode 100755 index 0000000..d79c1cd --- /dev/null +++ b/mushr_rhc_ros/src/rhcnode_network_corl.py @@ -0,0 +1,900 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2019, The Personal Robotics Lab, The MuSHR Team, The Contributors of MuSHR +# License: BSD 3-Clause. See LICENSE.md file in root directory. 
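+#
+# Deployment node: maintains a sliding window of processed laser scans and past
+# actions, runs the pretrained GPT models on it (action prediction, plus optional
+# map-reconstruction and localization heads), and publishes Ackermann drive
+# commands and RViz visualization markers.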
+ +# from torchsummary import summary +import sys +import os +import signal +import threading +import random +import numpy as np +from queue import Queue +import time +from collections import OrderedDict +import math +import copy + +import rospy +from ackermann_msgs.msg import AckermannDriveStamped +from geometry_msgs.msg import Point, PoseStamped, PoseWithCovarianceStamped +from std_msgs.msg import ColorRGBA, Empty, String +from std_srvs.srv import Empty as SrvEmpty +from visualization_msgs.msg import Marker +from sensor_msgs.msg import LaserScan + +import logger +import parameters +import rhcbase +import rhctensor +import utilss +import librhc.utils as utils_other + +import torch +from mingpt.model_resnetdirect import ResnetDirect, ResnetDirectWithActions +# from mingpt.model_musher import GPT, GPTConfig +# from mingpt.model_mushr_rogerio import GPT, GPTConfig +from mingpt.model_mushr_nips import GPT, GPTConfig +import preprocessing_utils as pre +from visualization_msgs.msg import Marker + +# import torch_tensorrt + + +class RHCNode(rhcbase.RHCBase): + def __init__(self, dtype, params, logger, name): + rospy.init_node(name, anonymous=True, disable_signals=True) + + super(RHCNode, self).__init__(dtype, params, logger) + + self.scan_lock = threading.Lock() + self.pos_lock = threading.Lock() + self.act_lock = threading.Lock() + self.curr_pose = None + + self.reset_lock = threading.Lock() + self.inferred_pose_lock = threading.Lock() + self.inferred_pose_lock_prev = threading.Lock() + self._inferred_pose = None + self._inferred_pose_prev = None + self._time_of_inferred_pose = None + self._time_of_inferred_pose_prev = None + + self.hp_zerocost_ids = None + self.hp_map = None + self.hp_world = None + self.time_started_goal = None + self.num_trials = 0 + + self.cur_rollout = self.cur_rollout_ip = None + self.traj_pub_lock = threading.Lock() + + self.goal_event = threading.Event() + self.map_metadata_event = threading.Event() + self.ready_event = threading.Event() + self.events = [self.goal_event, self.map_metadata_event, self.ready_event] + self.run = True + + self.default_speed = 2.5 + # self.default_speed = 1.0 + self.default_angle = 0.0 + self.nx = None + self.ny = None + self.use_map = True + self.use_loc = True + self.points_viz_list = None + self.map_recon = None + self.loc_counter = 0 + + # network loading + print("Starting to load model") + os.environ["CUDA_VISIBLE_DEVICES"]=str(0) + device = torch.device('cuda') + # device = "cpu" + + self.device = device + self.clip_len = 16 + + # tests for IROS + saved_model_path = rospy.get_param("~model_path", 'default_value') + self.out_path = rospy.get_param("~out_path", 'default_value') + # saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/normal-kingfish/GPTiros_e2e_8gpu_2022-02-17_1645120431.7528405_2022-02-17_1645120431.7528613/model/epoch10.pth.tar' + + # saved_model_path = '/home/rb/downloaded_models/epoch30.pth.tar' + # saved_model_path = '/home/robot/weight_files/epoch15.pth.tar' + # saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/gpt_resnet18_0/GPTgpt_resnet18_4gpu_2022-01-24_1642987604.6403077_2022-01-24_1642987604.640322/model/epoch15.pth.tar' + # saved_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/gpt_resnet18_8_exp2/GPTgpt_resnet18_8gpu_exp2_2022-01-25_1643076745.003202_2022-01-25_1643076745.0032148/model/epoch12.pth.tar' + vocab_size = 100 + block_size = self.clip_len * 2 + max_timestep = 7 + # mconf = GPTConfig(vocab_size, block_size, max_timestep, + # n_layer=6, n_head=8, 
n_embd=128, model_type='GPT', use_pred_state=True, + # state_tokenizer='conv2D', train_mode='e2e', pretrained_model_path='') + mconf = GPTConfig(vocab_size, block_size, max_timestep, + n_layer=6, n_head=8, n_embd=128, model_type='GPT', use_pred_state=True, + state_tokenizer='resnet18', train_mode='e2e', pretrained_model_path='', pretrained_encoder_path='', loss='MSE', + map_decoder='deconv', map_recon_dim=64) + model = GPT(mconf, device) + # model=torch.nn.DataParallel(model) + + checkpoint = torch.load(saved_model_path, map_location=device) + # old code for loading model + # model.load_state_dict(checkpoint['state_dict']) + # new code for loading mode + new_checkpoint = OrderedDict() + for key in checkpoint['state_dict'].keys(): + new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key] + model.load_state_dict(new_checkpoint) + + # ckpt = torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'] + # for key in ckpt: + # print('********',key) + # model.load_state_dict(torch.load('/home/rb/downloaded_models/epoch30.pth.tar')['state_dict'], strict=True) + + model.eval() + # model.half() + model.to(device) + + # inputs = [torch_tensorrt.Input( + # states_shape=[1, self.clip_len, 200*200], + # actions_shape=[1, self.clip_len , 1], + # targets_shape=[1, self.clip_len , 1], + # timesteps_shape=[1, 1, 1], + # dtype=torch.half, + # )] + # enabled_precisions = {torch.float, torch.half} + # trt_ts_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions=enabled_precisions) + + self.model = model + print("Finished loading model") + + # mapping model + if self.use_map: + + # saved_map_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/tough-mongoose/GPTnips_map_2022-04-07_1649355173.7984643_2022-04-07_1649355173.7984767/model/epoch10.pth.tar' + + saved_map_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/humble-bee/GPTnips_map_finetune_nofreeze_2022-04-12_1649739603.323518_2022-04-12_1649739603.3235295/model/epoch29.pth.tar' + # saved_map_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/easy-hornet/GPTnips_map_finetune_2022-04-11_1649705561.8950737_2022-04-11_1649705561.895087/model/epoch15.pth.tar' + + map_mconf = GPTConfig(vocab_size, block_size, max_timestep, + n_layer=6, n_head=8, n_embd=128, model_type='GPT', use_pred_state=True, + state_tokenizer='resnet18', train_mode='map', pretrained_model_path='', pretrained_encoder_path='', loss='MSE', + map_decoder='deconv', map_recon_dim=64) + map_model = GPT(map_mconf, device) + # map_model=torch.nn.DataParallel(map_model) + checkpoint = torch.load(saved_map_model_path, map_location=device) + + # old code for loading model + # map_model.load_state_dict(checkpoint['state_dict']) + # new code for loading mode + new_checkpoint = OrderedDict() + for key in checkpoint['state_dict'].keys(): + new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key] + map_model.load_state_dict(new_checkpoint) + + map_model.eval() + map_model.to(device) + self.map_model = map_model + + # localization model + if self.use_loc: + # saved_loc_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/perfect-weevil/GPTnips_loc_1000xangle_2022-04-08_1649429799.7040765_2022-04-08_1649429799.7040896/model/epoch13.pth.tar' + + # saved_loc_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/humble-bee/GPTnips_loc_1trans1000xangle_finetune_nofreeze_2022-04-12_1649739606.8248725_2022-04-12_1649739606.824885/model/epoch29.pth.tar' + saved_loc_model_path = 
'/home/rb/hackathon_data/aml_outputs/log_output/easy-hornet/GPTnips_loc_1trans1000xangle_finetune_2022-04-11_1649705562.0525122_2022-04-11_1649705562.0525265/model/epoch29.pth.tar' + # saved_loc_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/vital-ray/GPTnips_loc_0p01trans100xangle_2022-04-11_1649694076.4556465_2022-04-11_1649694076.4556613/model/epoch11.pth.tar' + # saved_loc_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/literate-flea/GPTnips_loc_0p01trans1xangle_2022-04-08_1649430057.8460536_2022-04-08_1649430057.8460662/model/epoch9.pth.tar' + # saved_loc_model_path = '/home/rb/hackathon_data/aml_outputs/log_output/stirring-cricket/GPTnips_loc_1xangle_2022-04-07_1649355866.8893065_2022-04-07_1649355866.8893213/model/epoch9.pth.tar' + loc_mconf = GPTConfig(vocab_size, block_size, max_timestep, + n_layer=6, n_head=8, n_embd=128, model_type='GPT', use_pred_state=True, + state_tokenizer='resnet18', train_mode='loc', pretrained_model_path='', pretrained_encoder_path='', loss='MSE', + map_decoder='deconv', map_recon_dim=64) + loc_model = GPT(loc_mconf, device) + # map_model=torch.nn.DataParallel(map_model) + checkpoint = torch.load(saved_loc_model_path, map_location=device) + + # old code for loading model + # loc_model.load_state_dict(checkpoint['state_dict']) + # new code for loading mode + new_checkpoint = OrderedDict() + for key in checkpoint['state_dict'].keys(): + new_checkpoint[key.split("module.",1)[1]] = checkpoint['state_dict'][key] + loc_model.load_state_dict(new_checkpoint) + + loc_model.eval() + loc_model.to(device) + self.loc_model = loc_model + + + self.q_scans = Queue(maxsize = self.clip_len) + self.q_actions = Queue(maxsize = self.clip_len) + self.q_pos = Queue(maxsize = self.clip_len) + for i in range(self.clip_len): + self.q_actions.put(self.default_angle) + self.last_action = self.default_angle + self.compute_network = False + self.compute_network_loc = False + self.has_loc_anchor = False + self.did_reset = False + + # parameters for model evaluation + self.reset_counter = 0 + self.last_reset_time = time.time() + self.distance_so_far = 0.0 + self.time_so_far = 0.0 + self.file_name = os.path.join(self.out_path,'info.csv') + + # set timer callbacks for visualization + rate_map_display = 1.0 + rate_loc_display = 10 + # self.map_viz_timer = rospy.Timer(rospy.Duration(1.0 / rate_map_display), self.map_viz_cb) + self.map_viz_loc = rospy.Timer(rospy.Duration(1.0 / rate_loc_display), self.loc_viz_cb) + + + def start(self): + self.logger.info("Starting RHController") + self.setup_pub_sub() + self.rhctrl = self.load_controller() + self.find_allowable_pts() # gets the allowed halton points from the map + + self.ready_event.set() + + rate_hz = 50 + rate = rospy.Rate(rate_hz) + self.logger.info("Initialized") + + # set initial pose for the car in the very first time in an allowable region + self.send_initial_pose() + # self.send_initial_pose_12f() + self.time_started = rospy.Time.now() + + # wait until we actually have a car pose + rospy.loginfo("Waiting to receive pose") + while not rospy.is_shutdown() and self.inferred_pose is None: + pass + rospy.loginfo("Vehicle pose received") + + while not rospy.is_shutdown() and self.run: + + # check if we should reset the vehicle if crashed + if self.check_reset(rate_hz): + rospy.loginfo("Resetting the car's position") + + # publish next action + if self.compute_network: + # don't have to run the network at all times, only when scans change and scans are full + self.last_action = self.apply_network() + 
self.act_lock.acquire() + self.q_actions.get() # remove the oldest action from the queue + self.q_actions.put(self.last_action) + self.act_lock.release() + # rospy.loginfo("Applied network: "+str(self.last_action)) + self.compute_network = False + + self.publish_vel_marker() + self.publish_traj(self.default_speed, self.last_action) + + # if map is not None: + rate.sleep() + + def map_viz_cb(self, timer): + self.pos_lock.acquire() + pos_queue_list = list(self.q_pos.queue) + pos_size = len(pos_queue_list) + self.pos_lock.release() + if pos_size==16: + x_imgs, x_act, t = self.prepare_model_inputs() + start = time.time() + # with torch.set_grad_enabled(False): + with torch.inference_mode(): + self.map_recon, loss = self.map_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) + finished_map_network = time.time() + rospy.loginfo("map network delay: "+str(finished_map_network-start)) + pose_mid = pos_queue_list[int(pos_size/2) -1] + # publish the GT pose of the map center + self.pose_marker_pub.publish(self.create_position_marker(pose_mid)) + # publish the map itself + self.map_marker_pub.publish(self.create_map_marker(pose_mid)) + + def loc_viz_cb_incremental(self, timer): + if self.compute_network_loc is False: + return + self.pos_lock.acquire() + pos_queue_list = list(self.q_pos.queue) + pos_size = len(pos_queue_list) + self.pos_lock.release() + + # create anchor pose for localization + if time.time()-self.time_sent_reset>3.0 and self.did_reset is True: + self.loc_counter = 0 + self.pose_anchor = copy.deepcopy(self.curr_pose) + self.current_pose = copy.deepcopy(self.pose_anchor) + self.has_loc_anchor = True + self.did_reset = False + + if self.loc_counter>=1 and self.has_loc_anchor is True: + self.loc_counter = 0 + x_imgs, x_act, t = self.prepare_model_inputs() + start = time.time() + # with torch.set_grad_enabled(False): + with torch.inference_mode(): + pose_preds, loss = self.loc_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) + finished_loc_network = time.time() + rospy.loginfo("loc network delay: "+str(finished_loc_network-start)) + # publish anchor pose of the map center + self.loc_anchor_pose_marker_pub.publish(self.create_position_marker(self.pose_anchor, color=[0,0,1,1])) + # publish the current accumulated pose + pose_pred = pose_preds[0,self.clip_len-1,:].cpu().numpy() + self.current_pose = self.sum_stamped_poses(self.current_pose, pose_pred) + self.loc_current_pose_marker_pub.publish(self.create_position_marker(self.current_pose, color=[1,0,0,1])) + + def loc_viz_cb(self, timer): + if self.compute_network_loc is False: + return + self.pos_lock.acquire() + pos_queue_list = list(self.q_pos.queue) + pos_size = len(pos_queue_list) + self.pos_lock.release() + + # create anchor pose for localization + if time.time()-self.time_sent_reset>3.0 and self.did_reset is True: + self.loc_counter = 0 + self.has_loc_anchor = False + self.did_reset = False + rospy.logwarn("Resetting the loc position") + + # set the anchor and equal to the first reference when we count 16 scans after reset + if self.loc_counter>=16 and self.has_loc_anchor is False: + rospy.logwarn("Setting the loc anchor position when completed 16 scans") + self.pose_anchor = copy.deepcopy(self.curr_pose) + self.current_frame = copy.deepcopy(self.pose_anchor) + self.has_loc_anchor = True + + if self.loc_counter>16 and self.has_loc_anchor is True and self.compute_network_loc is True: + x_imgs, x_act, t = 
self.prepare_model_inputs() + start = time.time() + # with torch.set_grad_enabled(False): + with torch.inference_mode(): + pose_preds, loss = self.loc_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) + finished_loc_network = time.time() + rospy.loginfo("loc network delay: "+str(finished_loc_network-start)) + # publish anchor pose of the map center + self.loc_anchor_pose_marker_pub.publish(self.create_position_marker(self.pose_anchor, color=[0,0,1,1])) + # publish the current accumulated pose + delta_pose_pred = pose_preds[0,self.clip_len-1,:].cpu().numpy() + # calculate the change in coordinates + self.current_frame = self.transform_poses(self.current_frame, delta_pose_pred) + self.loc_current_pose_marker_pub.publish(self.create_position_marker(self.current_frame, color=[1,0,0,1])) + + # if pos_size==16 and self.has_loc_anchor is True: + # x_imgs, x_act, t = self.prepare_model_inputs() + # start = time.time() + # with torch.set_grad_enabled(False): + # pose_preds, loss = self.loc_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None) + # finished_loc_network = time.time() + # rospy.loginfo("loc network delay: "+str(finished_loc_network-start)) + # # publish anchor pose of the map center + # self.loc_anchor_pose_marker_pub.publish(self.create_position_marker(self.pose_anchor, color=[0,0,1,1])) + # # publish the current accumulated pose + # pose_f = pose_preds[0,self.clip_len-1,:].cpu().numpy() + # pose_semi_f = pose_preds[0,self.clip_len-2,:].cpu().numpy() + # self.current_pose = self.calc_stamped_poses(self.current_pose, pose_f, pose_semi_f) + # self.loc_current_pose_marker_pub.publish(self.create_position_marker(self.current_pose, color=[1,0,0,1])) + + def transform_poses(self, current_pose, delta_pose_pred): + # elements of homogeneous matrix expressing point from local frame into world frame coords + current_angle = utilss.rosquaternion_to_angle(current_pose.pose.orientation) + R = np.array([[np.cos(current_angle),-np.sin(current_angle)], + [np.sin(current_angle),np.cos(current_angle)]]) + t = np.array([[current_pose.pose.position.x], + [current_pose.pose.position.y]]) + T = np.array([[R[0,0],R[0,1],t[0,0]], + [R[1,0],R[1,1],t[1,0]], + [0,0,1]]) + # now transform the position of the next point from local to world frame + pose_local = np.array([[delta_pose_pred[0]], + [delta_pose_pred[1]], + [1]]) + pose_world = np.matmul(T, pose_local) + current_pose.pose.position.x = pose_world[0,0] + current_pose.pose.position.y = pose_world[1,0] + current_angle += delta_pose_pred[2] + current_pose.pose.orientation = utilss.angle_to_rosquaternion(current_angle) + return current_pose + + def calc_stamped_poses(self, pos1, pose_f, pose_semi_f): + pos1.pose.position.x = pos1.pose.position.x + pose_f[0] - pose_semi_f[0] + pos1.pose.position.y = pos1.pose.position.y + pose_f[1] - pose_semi_f[1] + angle_f = math.atan2(pose_f[3], pose_f[2]) + angle_semi_f = math.atan2(pose_semi_f[3], pose_semi_f[2]) + orig_angle = utilss.rosquaternion_to_angle(pos1.pose.orientation) + pos1.pose.orientation = utilss.angle_to_rosquaternion(orig_angle+angle_f-angle_semi_f) + return pos1 + + def sum_stamped_poses(self, pos1, pose_pred): + pos1.pose.position.x += pose_pred[0] + pos1.pose.position.y += pose_pred[1] + angle = math.atan2(pose_pred[3], pose_pred[2]) + orig_angle = utilss.rosquaternion_to_angle(pos1.pose.orientation) + pos1.pose.orientation = utilss.angle_to_rosquaternion(orig_angle+angle) + return pos1 + + def 
create_map_marker(self, pose_stamped): + + start = time.time() + marker = Marker() + marker.header.frame_id = "/map" + marker.header.stamp = rospy.Time.now() + marker.type = 8 # points + marker.id = 0 + + # Set the scale of the marker + marker.scale.x = 0.1 + marker.scale.y = 0.1 + + map_recon = self.map_recon.cpu().numpy()[0,0,:] + [w,h] = map_recon.shape + + m_per_px = 12.0/64.0 + + # iterate over all pixels from the image to create the map in points + # only do this for the very first time. then skip because rel pos is the same + if self.points_viz_list is None: + self.points_viz_list = [] + for i in range(w): + for j in range(h): + p = Point() + p.x = +6.0 - i*m_per_px + p.y = +6.0 - j*m_per_px + p.z = 0.0 + self.points_viz_list.append(p) + marker.points = self.points_viz_list + + finished_points = time.time() + rospy.loginfo("points delay: "+str(finished_points-start)) + + # loop to figure out the colors for each point + for i in range(w): + for j in range(h): + cell_val = map_recon[i,j] + if cell_val < 0.2: + alpha = 0.0 + else: + alpha = 0.7 + color = ColorRGBA(r=0.0, g=min(max(cell_val,0.0),1.0), b=0.0, a=alpha) + marker.colors.append(color) + + finished_colors = time.time() + rospy.loginfo("color delay: "+str(finished_colors-finished_points)) + + # Set the pose of the marker + marker.pose = pose_stamped.pose + return marker + + def create_position_marker(self, pose_stamped, color=[0,1,0,1]): + marker = Marker() + marker.header.frame_id = "/map" + marker.header.stamp = rospy.Time.now() + marker.type = 0 # arrow + marker.id = 0 + + # Set the scale of the marker + marker.scale.x = 1 + marker.scale.y = 0.1 + marker.scale.z = 0.1 + + # Set the color + marker.color.r = color[0] + marker.color.g = color[1] + marker.color.b = color[2] + marker.color.a = color[3] + + # Set the pose of the marker + marker.pose = pose_stamped.pose + return marker + + + def publish_vel_marker(self): + marker = Marker() + marker.header.frame_id = "/car/base_link" + marker.header.stamp = rospy.Time.now() + marker.type = 0 # arrow + marker.id = 0 + + # Set the scale of the marker + marker.scale.x = 1 + marker.scale.y = 0.1 + marker.scale.z = 0.1 + + # Set the color + marker.color.r = 1.0 + marker.color.g = 0.0 + marker.color.b = 0.0 + marker.color.a = 1.0 + + # set the first point + # point_start = Point() + # point_start.x = point_start.y = point_start.z = 0.0 + # marker.points.append(point_start) + + # l = 5.0 + # point_end = Point() + # point_end.x = l*np.cos(self.last_action) + # point_end.y = l*np.sin(self.last_action) + # point_end.z = 0.0 + # marker.points.append(point_end) + + # Set the pose of the marker + marker.pose.position.x = 0.32 + marker.pose.position.y = 0 + marker.pose.position.z = 0 + marker.pose.orientation = utilss.angle_to_rosquaternion(self.last_action) + # marker.pose.orientation.x = 0.0 + # marker.pose.orientation.y = 0.0 + # marker.pose.orientation.z = 0.0 + # marker.pose.orientation.w = 1.0 + + self.vel_marker_pub.publish(marker) + + + + def apply_network(self): + x_imgs, x_act, t = self.prepare_model_inputs() + start = time.time() + # with torch.set_grad_enabled(False): + with torch.inference_mode(): + # action_pred = 0.0 + action_pred, loss = self.model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) + finished_action_network = time.time() + rospy.loginfo("action network delay: "+str(finished_action_network-start)) + action_pred = action_pred[0,self.clip_len-1,0].cpu().flatten().item() + # if self.use_map: + # map_pred, loss = 
self.map_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None) + # finished_map_network = time.time() + # rospy.loginfo("map network delay: "+str(finished_map_network-finished_action_network)) + finished_network = time.time() + # rospy.loginfo("network delay: "+str(finished_network-finish_processing)) + + # de-normalize + action_pred = pre.denorm_angle(action_pred) + return action_pred + # if self.use_map: + # return action_pred + # else: + # return action_pred + + def prepare_model_inputs(self): + start = time.time() + # organize the scan input + x_imgs = torch.zeros(1,self.clip_len,self.nx,self.ny) + x_act = torch.zeros(1,self.clip_len) + + self.scan_lock.acquire() + queue_list = list(self.q_scans.queue) + queue_size = self.q_scans.qsize() + self.scan_lock.release() + + idx = 0 + for img in queue_list: + x_imgs[0,idx,:] = torch.tensor(img) + idx+=1 + idx = 0 + self.act_lock.acquire() + for act in self.q_actions.queue: + x_act[0,idx] = torch.tensor(act) + idx+=1 + self.act_lock.release() + + x_imgs = x_imgs.contiguous().view(1, self.clip_len, 200*200) + x_imgs = x_imgs.to(self.device) + + x_act = x_act.view(1, self.clip_len , 1) + x_act = x_act.to(self.device) + + t = np.ones((1, 1, 1), dtype=int) * 7 + t = torch.tensor(t) + t = t.to(self.device) + + finish_processing = time.time() + # rospy.loginfo("processing delay: "+str(finish_processing-start)) + return x_imgs, x_act, t + + def check_reset(self, rate_hz): + # condition if the car gets stuck + if self.inferred_pose_prev() is not None and self.time_started is not None and self._time_of_inferred_pose is not None and self._time_of_inferred_pose_prev is not None: + # calculate distance traveled + delta_dist = np.linalg.norm(np.asarray(self.inferred_pose())-np.asarray(self.inferred_pose_prev())) + v = 2.0 # default value + if delta_dist < 0.5: + delta_time_poses = (self._time_of_inferred_pose-self._time_of_inferred_pose_prev).to_sec() + self.distance_so_far += delta_dist + self.time_so_far += delta_time_poses + # look at speed and termination condition + v = delta_dist / delta_time_poses + # print('v = {}'.format(v)) + if v < 0.05 and rospy.Time.now().to_sec() - self.time_started.to_sec() > 1.0: + # this means that the car was supposed to follow a traj, but velocity is too low bc it's stuck + # first we reset the car pose + self.reset_counter +=1 + if self.reset_counter > 5 : + # save distance data to file and reset distance + delta_time = time.time() - self.last_reset_time + print("Distance: {} | Time: {} | Time so far: {}".format(self.distance_so_far, delta_time, self.time_so_far)) + with open(self.file_name,'a') as fd: + fd.write(str(self.distance_so_far)+','+str(self.time_so_far)+'\n') + self.send_initial_pose() + # self.send_initial_pose_12f() + rospy.loginfo("Got stuck, resetting pose of the car to default value") + msg = String() + msg.data = "got stuck" + self.expr_at_goal.publish(msg) + self.reset_counter = 0 + # new_line = np.array([self.distance_so_far, delta_time]) + # self.out_file = open(self.file_name,'ab') + # np.savetxt(self.out_file, new_line, delimiter=',') + # self.out_file.close() + self.distance_so_far = 0.0 + self.time_so_far = 0.0 + self.last_reset_time = time.time() + return True + else: + return False + + def send_initial_pose(self): + # sample a initial pose for the car based on the valid samples + hp_world_valid = self.hp_world[self.hp_zerocost_ids] + new_pos_idx = np.random.randint(0, hp_world_valid.shape[0]) + msg = PoseWithCovarianceStamped() + msg.header.stamp = 
rospy.Time.now() + msg.header.frame_id = "map" + # msg.pose.pose.position.x = hp_world_valid[new_pos_idx][0] + # msg.pose.pose.position.y = hp_world_valid[new_pos_idx][1] + # msg.pose.pose.position.z = 0.0 + # quat = utilss.angle_to_rosquaternion(hp_world_valid[new_pos_idx][1]) + msg.pose.pose.position.x = 4.12211 + (np.random.rand()-0.5)*2.0*0.5 + msg.pose.pose.position.y = -7.49623 + (np.random.rand()-0.5)*2.0*0.5 + msg.pose.pose.position.z = 0.0 + quat = utilss.angle_to_rosquaternion(np.radians(68 + (np.random.rand()-0.5)*2.0*360)) # 360 instead of zero at the end + msg.pose.pose.orientation = quat + + self.did_reset = True + self.time_sent_reset = time.time() + + # # create anchor pose for localization + # self.pose_anchor = PoseStamped() + # self.pose_anchor.header = msg.header + # self.pose_anchor.pose = msg.pose.pose + # self.current_pose = copy.deepcopy(self.pose_anchor) + # self.has_loc_anchor = True + + self.pose_reset.publish(msg) + + def send_initial_pose_12f(self): + # sample a initial pose for the car based on the valid samples + hp_world_valid = self.hp_world[self.hp_zerocost_ids] + new_pos_idx = np.random.randint(0, hp_world_valid.shape[0]) + msg = PoseWithCovarianceStamped() + msg.header.stamp = rospy.Time.now() + msg.header.frame_id = "map" + # msg.pose.pose.position.x = hp_world_valid[new_pos_idx][0] + # msg.pose.pose.position.y = hp_world_valid[new_pos_idx][1] + # msg.pose.pose.position.z = 0.0 + # quat = utilss.angle_to_rosquaternion(hp_world_valid[new_pos_idx][1]) + msg.pose.pose.position.x = -3.3559 + (np.random.rand()-0.5)*2.0*0.5 + msg.pose.pose.position.y = 4.511 + (np.random.rand()-0.5)*2.0*0.5 + msg.pose.pose.position.z = 0.0 + quat = utilss.angle_to_rosquaternion(np.radians(-83.115 + (np.random.rand()-0.5)*2.0*360)) # 360 instead of zero at the end + msg.pose.pose.orientation = quat + + self.did_reset = True + self.time_sent_reset = time.time() + + self.pose_reset.publish(msg) + + def shutdown(self, signum, frame): + rospy.signal_shutdown("SIGINT recieved") + self.run = False + for ev in self.events: + ev.set() + + def process_scan(self, msg): + scan = np.zeros((721), dtype=np.float) + scan[0] = msg.header.stamp.to_sec() + scan[1:] = msg.ranges + original_points, sensor_origins, time_stamps, pc_range, voxel_size, lo_occupied, lo_free = pre.load_params(scan) + vis_mat, nx, ny = pre.compute_bev_image(original_points, sensor_origins, time_stamps, pc_range, voxel_size) + if self.nx is None: + self.nx = nx + self.ny = ny + return vis_mat + + def cb_scan(self, msg): + + # remove element from position queue: + self.pos_lock.acquire() + if self.q_pos.full(): + self.q_pos.get() # remove the oldest element, will be replaced next + self.pos_lock.release() + + # add new vehicle position + self.pos_lock.acquire() + if self.curr_pose is None: + self.pos_lock.release() + # exist the callback if there is no current pose: will only happen at the very beginning + return + else: + self.q_pos.put(self.curr_pose) + self.pos_lock.release() + + # remove oldest element if the queue is already full + self.scan_lock.acquire() + if self.q_scans.full(): + self.compute_network = True # start running the network in the main loop from now on + self.compute_network_loc = True + self.loc_counter += 1 + self.q_scans.get() # remove the oldest element, will be replaced next + self.scan_lock.release() + + # add new processed scan + tmp = self.process_scan(msg) + self.scan_lock.acquire() + self.q_scans.put(tmp) # store matrices from 0-1 with the scans + self.scan_lock.release() + + + def 
setup_pub_sub(self): + rospy.Service("~reset/soft", SrvEmpty, self.srv_reset_soft) + rospy.Service("~reset/hard", SrvEmpty, self.srv_reset_hard) + + car_name = self.params.get_str("car_name", default="car") + + rospy.Subscriber( + "/" + car_name + "/" + 'scan', + LaserScan, + self.cb_scan, + queue_size=10, + ) + + rospy.Subscriber( + "/" + car_name + "/" + rospy.get_param("~inferred_pose_t"), + PoseStamped, + self.cb_pose, + queue_size=10, + ) + + self.rp_ctrls = rospy.Publisher( + "/" + + car_name + + "/" + + self.params.get_str( + "ctrl_topic", default="mux/ackermann_cmd_mux/input/navigation" + ), + AckermannDriveStamped, + queue_size=2, + ) + + self.vel_marker_pub = rospy.Publisher("/model_action_marker", Marker, queue_size = 1) + + # markers for mapping visualization + self.pose_marker_pub = rospy.Publisher("/pose_marker", Marker, queue_size = 1) + self.map_marker_pub = rospy.Publisher("/map_marker", Marker, queue_size = 1) + + # markers for localization visualization + self.loc_anchor_pose_marker_pub = rospy.Publisher("/loc_anchor_pose_marker", Marker, queue_size = 1) + self.loc_current_pose_marker_pub = rospy.Publisher("/loc_current_pose_marker", Marker, queue_size = 1) + + self.pose_reset = rospy.Publisher("/initialpose", PoseWithCovarianceStamped, queue_size=1) + + traj_chosen_t = self.params.get_str("traj_chosen_topic", default="~traj_chosen") + self.traj_chosen_pub = rospy.Publisher(traj_chosen_t, Marker, queue_size=10) + + # For the experiment framework, need indicators to listen on + self.expr_at_goal = rospy.Publisher("experiments/finished", String, queue_size=1) + + # to publish the new goal, for visualization + self.goal_pub = rospy.Publisher("~goal", Marker, queue_size=10) + + def srv_reset_hard(self, msg): + """ + Hard reset does a complete reload of the controller + """ + rospy.loginfo("Start hard reset") + self.reset_lock.acquire() + self.load_controller() + self.goal_event.clear() + self.reset_lock.release() + rospy.loginfo("End hard reset") + return [] + + def srv_reset_soft(self, msg): + """ + Soft reset only resets soft state (like tensors). 
No dependencies or maps + are reloaded + """ + rospy.loginfo("Start soft reset") + self.reset_lock.acquire() + self.rhctrl.reset() + self.goal_event.clear() + self.reset_lock.release() + rospy.loginfo("End soft reset") + return [] + + def find_allowable_pts(self): + self.hp_map, self.hp_world = self.rhctrl.cost.value_fn._get_halton_pts() + self.hp_zerocost_ids = np.zeros(self.hp_map.shape[0], dtype=bool) + for i, pts in enumerate(self.hp_map): + pts = pts.astype(np.int) + if int(pts[0])3.0 and self.did_reset is True: @@ -360,7 +356,7 @@ def loc_viz_cb(self, timer): start = time.time() # with torch.set_grad_enabled(False): with torch.inference_mode(): - pose_preds, loss = self.loc_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) + pose_preds, _, _, _, _ = self.loc_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) finished_loc_network = time.time() rospy.loginfo("loc network delay: "+str(finished_loc_network-start)) # publish anchor pose of the map center @@ -371,6 +367,7 @@ def loc_viz_cb(self, timer): # calculate the change in coordinates self.current_frame = self.transform_poses(self.current_frame, delta_pose_pred) self.loc_current_pose_marker_pub.publish(self.create_position_marker(self.current_frame, color=[1,0,0,1])) + self.compute_network_loc = False # if pos_size==16 and self.has_loc_anchor is True: # x_imgs, x_act, t = self.prepare_model_inputs() @@ -550,10 +547,10 @@ def apply_network(self): # with torch.set_grad_enabled(False): with torch.inference_mode(): # action_pred = 0.0 - action_pred, loss = self.model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) + action_pred, _, _, _ = self.model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None, compute_loss=False) finished_action_network = time.time() rospy.loginfo("action network delay: "+str(finished_action_network-start)) - action_pred = action_pred[0,self.clip_len-1,0].cpu().flatten().item() + action_pred = action_pred[0,-1,0].cpu().flatten().item() # if self.use_map: # map_pred, loss = self.map_model(states=x_imgs, actions=x_act, targets=x_act, gt_map=None, timesteps=t, poses=None) # finished_map_network = time.time() @@ -598,9 +595,12 @@ def prepare_model_inputs(self): x_act = x_act.view(1, self.clip_len , 1) x_act = x_act.to(self.device) - t = np.ones((1, 1, 1), dtype=int) * 7 - t = torch.tensor(t) - t = t.to(self.device) + t = torch.arange(0, self.clip_len).view(1,-1).to(self.device) + t = t.repeat(1,1) + + # t = np.ones((1, 1, 1), dtype=int) * 7 + # t = torch.tensor(t) + # t = t.to(self.device) finish_processing = time.time() # rospy.loginfo("processing delay: "+str(finish_processing-start)) @@ -617,7 +617,8 @@ def check_reset(self, rate_hz): self.distance_so_far += delta_dist self.time_so_far += delta_time_poses # look at speed and termination condition - v = delta_dist / delta_time_poses + if delta_time_poses > 0.001: + v = delta_dist / delta_time_poses # print('v = {}'.format(v)) if v < 0.05 and rospy.Time.now().to_sec() - self.time_started.to_sec() > 1.0: # this means that the car was supposed to follow a traj, but velocity is too low bc it's stuck @@ -734,9 +735,6 @@ def cb_scan(self, msg): # remove oldest element if the queue is already full self.scan_lock.acquire() if self.q_scans.full(): - self.compute_network = True # start running the network in the main loop from now on - self.compute_network_loc = True - 
self.loc_counter += 1 self.q_scans.get() # remove the oldest element, will be replaced next self.scan_lock.release() @@ -745,6 +743,11 @@ def cb_scan(self, msg): self.scan_lock.acquire() self.q_scans.put(tmp) self.scan_lock.release() + + # control flags for other processes are activated now that queues have been updated + self.compute_network = True # start running the network in the main loop from now on + self.compute_network_loc = True + self.loc_counter += 1 def setup_pub_sub(self):