Dev/extract traj #281

Merged · 9 commits · Sep 12, 2023

Changes from all commits
190 changes: 190 additions & 0 deletions scripts/dataset_generation/extract_binary_maps.py
@@ -0,0 +1,190 @@
import sys
import os
from pathlib import Path
import time
from tqdm import tqdm
import subprocess
import yaml

from tf_bag import BagTfTransformer
import rospy
import rosparam
from sensor_msgs.msg import Image, CameraInfo, CompressedImage
import rosbag

from postprocessing_tools_ros.merging import merge_bags_single, merge_bags_all

# from py_image_proc_cuda import ImageProcCuda
# from cv_bridge import CvBridge

from wild_visual_navigation import WVN_ROOT_DIR
from wild_visual_navigation.utils import perugia_dataset, ROOT_DIR

sys.path.append(f"{WVN_ROOT_DIR}/wild_visual_navigation_ros/scripts")
from wild_visual_navigation_node import WvnRosInterface

sys.path.append(f"{WVN_ROOT_DIR}/wild_visual_navigation_anymal/scripts")
from anymal_msg_converter_node import anymal_msg_callback

# This script does the following:
# 1. Debayer cam4 -> send via ROS and wait for the result ("correct params")
# 2. Remap anymal_state_topic -> /wild_visual_navigation_node/robot_state
# 3. Feed the data into wild_visual_navigation_node ("correct params")
# 4. Iterate over the rosbags


def get_bag_info(rosbag_path: str) -> dict:
    # Queries `rosbag info` via a subprocess and parses its YAML output (used below to count messages per topic)
info_dict = yaml.safe_load(
subprocess.Popen(["rosbag", "info", "--yaml", rosbag_path], stdout=subprocess.PIPE).communicate()[0]
)
return info_dict
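

# The parsed dict mirrors the `rosbag info --yaml` output, e.g. (illustrative values, not from an actual bag):
#   {"duration": 42.0, "topics": [{"topic": "/tf", "type": "tf2_msgs/TFMessage", "messages": 1000}, ...]}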


class BagTfTransformerWrapper:
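    """Minimal adapter exposing a tf-listener-style interface (waitForTransform/lookupTransform) on top of BagTfTransformer for offline rosbag replay."""
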
def __init__(self, bag):
self.tf_listener = BagTfTransformer(bag)

    def waitForTransform(self, parent_frame, child_frame, time, duration):
        # `duration` is unused: BagTfTransformer resolves transforms offline from the bag
        return self.tf_listener.waitForTransform(parent_frame, child_frame, time)

    def lookupTransform(self, parent_frame, child_frame, time):
        try:
            return self.tf_listener.lookupTransform(parent_frame, child_frame, time)
        except Exception:
            # No transform available at the requested time
            return (None, None)


def do(n, dry_run):
d = perugia_dataset[n]

if bool(dry_run):
print(d)
return

s = os.path.join(ROOT_DIR, d["name"])

valid_topics = ["/state_estimator/anymal_state", "/wide_angle_camera_front/img_out"]

    rosbags = [
        "/home/rschmid/RosBags/6_proc/images.bag",
        "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_0.bag",
        "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_1.bag",
    ]

output_bag_wvn = s + "_wvn.bag"
output_bag_tf = s + "_tf.bag"

if not os.path.exists(output_bag_tf):
total_included_count, total_skipped_count = merge_bags_single(
input_bag=rosbags, output_bag=output_bag_tf, topics="/tf /tf_static", verbose=True
)
if not os.path.exists(output_bag_wvn):
total_included_count, total_skipped_count = merge_bags_single(
input_bag=rosbags, output_bag=output_bag_wvn, topics=" ".join(valid_topics), verbose=True
)

# Setup WVN node
rospy.init_node("wild_visual_navigation_node")

mission = s.split("/")[-1]

running_store_folder = f"/home/rschmid/RosBags/output/{mission}"

if os.path.exists(running_store_folder):
print("Folder already exists, but proceeding!")
# return

rosparam.set_param("wild_visual_navigation_node/mode", "extract_labels")
rosparam.set_param("wild_visual_navigation_node/running_store_folder", running_store_folder)

# for proprioceptive callback
state_msg_valid = False
desired_twist_msg_valid = False

wvn_ros_interface = WvnRosInterface()
print("-" * 80)

print("start loading tf")
tf_listener = BagTfTransformerWrapper(output_bag_tf)
wvn_ros_interface.setup_rosbag_replay(tf_listener)
print("done loading tf")

    # Camera intrinsics for the Höngg (new) recordings: equidistant fisheye model
info_msg = CameraInfo()
info_msg.height = 1080
info_msg.width = 1440
info_msg.distortion_model = "equidistant"
info_msg.D = [0.4316922809468283, 0.09279900476637248, -0.4010909691803734, 0.4756163338479413]
info_msg.K = [575.6050407221768, 0.0, 745.7312198525915, 0.0, 578.564849365178, 519.5207040671075, 0.0, 0.0, 1.0]
info_msg.P = [575.6050407221768, 0.0, 745.7312198525915, 0.0, 0.0, 578.564849365178, 519.5207040671075, 0.0, 0.0, 0.0, 1.0, 0.0]
info_msg.R = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
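    # CameraInfo convention: K is the row-major 3x3 intrinsic matrix [fx 0 cx; 0 fy cy; 0 0 1],
    # D holds the four distortion coefficients of the equidistant (fisheye) model, and P is the 3x4 projection matrix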

rosbag_info_dict = get_bag_info(output_bag_wvn)
total_msgs = sum([x["messages"] for x in rosbag_info_dict["topics"] if x["topic"] in valid_topics])
total_time_img = 0
total_time_state = 0
n = 0

with rosbag.Bag(output_bag_wvn, "r") as bag:
if rospy.is_shutdown():
return

start_time = rospy.Time.from_sec(bag.get_start_time() + d["start"])
end_time = rospy.Time.from_sec(bag.get_start_time() + d["stop"])
with tqdm(
total=total_msgs,
desc="Total",
colour="green",
position=1,
bar_format="{desc:<13}{percentage:3.0f}%|{bar:20}{r_bar}",
) as pbar:
for (topic, msg, ts) in bag.read_messages(topics=None, start_time=start_time, end_time=end_time):

if rospy.is_shutdown():
return

pbar.update(1)
st = time.time()
if topic == "/state_estimator/anymal_state":
state_msg = anymal_msg_callback(msg, return_msg=True)
state_msg_valid = True

elif topic == "/wide_angle_camera_front/img_out":
image_msg = msg
# print("Received /wide_angle_camera_front/img_out")

                    info_msg.header = msg.header
                    camera_options = {}
                    camera_options["name"] = "wide_angle_camera_front"
                    camera_options["use_for_training"] = True
try:
wvn_ros_interface.image_callback(image_msg, info_msg, camera_options)
except Exception as e:
print("Bad image_callback", e)

total_time_img += time.time() - st
# print(f"image time: {total_time_img} , state time: {total_time_state}")
# print("add image")
if state_msg_valid:
try:
wvn_ros_interface.robot_state_callback(state_msg, None)
except Exception as e:
print("Bad robot_state callback ", e)

state_msg_valid = False
total_time_state += time.time() - st

print("Finished with converting the dataset")
rospy.signal_shutdown("stop the node")


if __name__ == "__main__":
import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=0, help="Index of the perugia_dataset entry to process")
    parser.add_argument("--dry_run", type=int, default=0, help="If nonzero, only print the dataset entry and exit")
args = parser.parse_args()

do(args.n, args.dry_run)
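
For reference, a minimal sketch of how this script might be exercised; the bag path is one of the hard-coded ones above, and the full run assumes a running roscore (hypothetical session, not part of the PR):

# Hypothetical usage (assumes the module is importable as extract_binary_maps):
from extract_binary_maps import get_bag_info, do

info = get_bag_info("/home/rschmid/RosBags/6_proc/images.bag")
for t in info["topics"]:
    print(t["topic"], t["messages"])   # message counts per topic

do(n=0, dry_run=1)    # print perugia_dataset[0] and return
# do(n=0, dry_run=0)  # full extraction; requires a running roscore and the merged bags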
46 changes: 46 additions & 0 deletions wild_visual_navigation/image_projector/image_projector.py
@@ -189,6 +189,52 @@ def project_and_render(

return self.masks, image_overlay, projected_points, valid_points

    def project_and_render_on_map(
        self,
        pose_base_in_world: torch.tensor,
        points: torch.tensor,
        colors: torch.tensor,
        map_resolution: float,
        map_size: int,
        image: torch.tensor = None,
    ):
"""Projects the points and returns an image with the projection

Args:
points: (torch.Tensor, dtype=torch.float32, shape=(B, N, 3)): B batches, of N input points in 3D space
colors: (torch.Tensor, rtype=torch.float32, shape=(B, 3))

Returns:
out_img (torch.tensor, dtype=torch.int64): Image with projected points
"""

# self.masks = self.masks * 0.0
B = self.camera.batch_size
C = 3 # RGB channel output
H = self.camera.height.item()
W = self.camera.width.item()
self.masks = torch.zeros((B, C, H, W), dtype=torch.float32, device=self.camera.camera_matrix.device)
image_overlay = image

T_BW = pose_base_in_world.inverse()
# Convert from fixed to base frame
points_B = transform_points(T_BW, points)

# Remove z dimension
# TODO: project footprint on gravity aligned plane
flat_points = points_B[:, :, :-1]

# Shift to grid map coordinates
flat_points = flat_points / map_resolution + map_size / 2
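        # Example: with map_resolution=0.1 m/cell and map_size=128, a point 1.5 m
        # ahead of the base lands in cell 1.5 / 0.1 + 128 / 2 = 79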

# Fill the mask
self.masks = draw_convex_polygon(self.masks, flat_points, colors)

# Draw on image (if applies)
if image is not None:
if len(image.shape) != 4:
image = image[None]
image_overlay = draw_convex_polygon(image, flat_points, colors)

# Return torch masks
self.masks[self.masks == 0.0] = torch.nan

return self.masks, image_overlay

def resize_image(self, image: torch.tensor):
return self.image_crop(image)
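
For context, a minimal sketch of how the new method might be called; the intrinsics, shapes, and map parameters below are hypothetical, not taken from the PR:

# Hypothetical usage sketch (assumed values throughout, not part of the PR)
import torch
from wild_visual_navigation.image_projector import ImageProjector

B = 1
K = torch.eye(4).repeat(B, 1, 1)  # (B, 4, 4) camera intrinsics
K[:, 0, 0] = 500.0  # fx (assumed)
K[:, 1, 1] = 500.0  # fy (assumed)
K[:, 0, 2] = 320.0  # cx (assumed)
K[:, 1, 2] = 240.0  # cy (assumed)
im = ImageProjector(K, 480, 640)  # height, width

pose_base_in_world = torch.eye(4).repeat(B, 1, 1)  # (B, 4, 4) base pose in the world frame
footprint = torch.rand(B, 10, 3)                   # (B, N, 3) footprint points
color = torch.ones(B, 3)                           # fill color per batch

# 0.1 m/cell on a 128x128 grid map (the new constructor defaults)
masks, overlay = im.project_and_render_on_map(pose_base_in_world, footprint, color, 0.1, 128)
# masks: (B, 3, H, W), NaN where nothing was projected; overlay is None since no image was given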

@@ -48,6 +48,8 @@ def __init__(
feature_type: str = "dino",
min_samples_for_training: int = 10,
vis_node_index: int = 10,
map_resolution: float = 0.1,
map_size: int = 128,
mode: bool = False,
extraction_store_folder=None,
anomaly_detection: bool = False,
@@ -62,6 +64,8 @@
self._scale_traversability = scale_traversability
self._params = params
self._scale_traversability_threshold = 0
self._map_resolution = map_resolution
self._map_size = map_size
self._anomaly_detection = anomaly_detection

if self._scale_traversability:
@@ -306,7 +310,7 @@ def add_mission_node(self, node: MissionNode, verbose: bool = False, update_feat

@accumulate_time
@torch.no_grad()
def add_proprio_node(self, pnode: ProprioceptionNode, projection_mode: str = "image"):
"""Adds a node to the proprioceptive graph to store proprioception

Args:
@@ -356,38 +360,48 @@ def add_proprio_node(self, pnode: ProprioceptionNode):
color = torch.ones((3,), device=self._device)

# New implementation
B = len(mission_nodes) # Number of mission nodes to project

# Prepare batches
K = torch.eye(4, device=self._device).repeat(B, 1, 1)
supervision_masks = torch.zeros(last_mission_node.supervision_mask.shape, device=self._device).repeat(
B, 1, 1, 1
)
pose_camera_in_world = torch.eye(4, device=self._device).repeat(B, 1, 1)
pose_base_in_world = torch.eye(4, device=self._device).repeat(B, 1, 1)

H = last_mission_node.image_projector.camera.height
W = last_mission_node.image_projector.camera.width
footprints = footprint.repeat(B, 1, 1)

for i, mnode in enumerate(mission_nodes):
K[i] = mnode.image_projector.camera.intrinsics
pose_camera_in_world[i] = mnode.pose_cam_in_world
pose_base_in_world[i] = mnode.pose_base_in_world

if (not hasattr(mnode, "supervision_mask")) or (mnode.supervision_mask is None):
continue
else:
supervision_masks[i] = mnode.supervision_mask # Getting all the existing supervision masks

im = ImageProjector(K, H, W)

map_resolution = self._map_resolution
map_size = self._map_size

        if projection_mode == "image":
            mask, _, _, _ = im.project_and_render(pose_camera_in_world, footprints, color)  # Generate the new supervision mask to add
        elif projection_mode == "map":
            mask, _ = im.project_and_render_on_map(pose_base_in_world, footprints, color, map_resolution, map_size)
        else:
            raise ValueError(f"Unknown projection_mode: {projection_mode}")

        # Update traversability (currently disabled)
        # mask = mask * pnode.traversability
        supervision_masks = torch.fmin(supervision_masks, mask)  # Merge the new mask into the accumulated supervision masks
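        # torch.fmin returns the element-wise minimum and prefers the non-NaN operand,
        # so cells that are NaN in the accumulated masks get filled by newly projected values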

# Update supervision mask per node
for i, mnode in enumerate(mission_nodes):
mnode.supervision_mask = supervision_masks[i]
            # mnode.update_supervision_signal()  # Accumulate supervision signal, check if features are there

if self._mode == WVNMode.EXTRACT_LABELS:
p = os.path.join(
@@ -398,6 +412,7 @@ def add_proprio_node(self, pnode: ProprioceptionNode):
store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0
torch.save(store, p)


# if self._anomaly_detection:
# # Visualize supervision mask
# store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0
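
For clarity, a brief sketch of the call-site difference introduced by the new projection_mode flag; the estimator and node names here are assumed, not from the PR:

# Hypothetical call sites (estimator/pnode are assumed to exist):
estimator.add_proprio_node(pnode)                          # default: project the footprint into the camera image
estimator.add_proprio_node(pnode, projection_mode="map")   # new: render the footprint on the gravity-aligned grid map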