Skip to content

Commit

Permalink
Fix bugs from interactive testing
Browse files Browse the repository at this point in the history
  • Loading branch information
sea-bass committed Nov 9, 2024
1 parent 5238b4f commit 0fe5a7a
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 60 deletions.
3 changes: 2 additions & 1 deletion pyrobosim/pyrobosim/core/robot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
""" Defines a robot which operates in a world. """

import logging
import time
import numpy as np

Expand Down Expand Up @@ -1020,6 +1019,8 @@ def to_dict(self):
robot_dict = {
"name": self.name,
"radius": self.radius,
"height": self.height,
"color": self.color,
"pose": pose.to_dict(),
"max_linear_velocity": float(self.dynamics.vel_limits[0]),
"max_angular_velocity": float(self.dynamics.vel_limits[-1]),
Expand Down
2 changes: 1 addition & 1 deletion pyrobosim/pyrobosim/core/world.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
""" Main file containing the core world modeling tools. """

import itertools
import logging
import numpy as np

from .hallway import Hallway
Expand Down Expand Up @@ -43,6 +42,7 @@ def __init__(
self.name = name
self.wall_height = wall_height
self.source_yaml = None
self.source_yaml_file = None
self.logger = create_logger(self.name)

# Connected apps
Expand Down
86 changes: 34 additions & 52 deletions pyrobosim/pyrobosim/core/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" Utilities to create worlds from YAML files. """

import numpy as np
import copy
import os
import yaml

Expand Down Expand Up @@ -52,21 +52,13 @@ def from_file(self, filename):
world_dict = yaml.load(file, Loader=yaml.FullLoader)
(world_dir, _) = os.path.split(filename)
world = self.from_yaml(world_dict, world_dir)
world.source_yaml = world_dict
world.source_yaml_file = filename
return world

def create_world(self):
"""Creates an initial world with the specified global parameters."""
if "params" in self.data:
params = self.data["params"]
self.world = World(
name=params.get("name", "world"),
inflation_radius=params.get("inflation_radius", 0.0),
object_radius=params.get("object_radius", 0.0),
wall_height=params.get("wall_height", 2.0),
)
else:
self.world = World()
params = self.data.get("params", {})
self.world = World(**params)

# Set the location/object metadata
metadata = self.data.get("metadata")
Expand All @@ -88,12 +80,12 @@ def create_world(self):
def add_rooms(self):
"""Add rooms to the world."""
for room_data in self.data.get("rooms", []):
if "nav_poses" in room_data:
room_data["nav_poses"] = [
Pose.construct(p) for p in room_data["nav_poses"]
room_args = copy.deepcopy(room_data)
if "nav_poses" in room_args:
room_args["nav_poses"] = [
Pose.construct(p) for p in room_args["nav_poses"]
]

self.world.add_room(**room_data)
self.world.add_room(**room_args)

def add_hallways(self):
"""Add hallways connecting rooms to the world."""
Expand All @@ -103,55 +95,42 @@ def add_hallways(self):
def add_locations(self):
"""Add locations for object spawning to the world."""
for loc_data in self.data.get("locations", []):
loc_data["pose"] = Pose.construct(loc_data["pose"])

self.world.add_location(**loc_data)
loc_args = copy.deepcopy(loc_data)
loc_args["pose"] = Pose.construct(loc_args["pose"])
self.world.add_location(**loc_args)

def add_objects(self):
"""Add objects to the world."""
if "objects" not in self.data:
return

for obj_data in self.data.get("objects", []):
if "pose" in obj_data:
obj_data["pose"] = Pose.construct(obj_data["pose"])
self.world.add_object(**obj_data)
obj_args = copy.deepcopy(obj_data)
if "pose" in obj_args:
obj_args["pose"] = Pose.construct(obj_args["pose"])
self.world.add_object(**obj_args)

def add_robots(self):
"""Add robots to the world."""
if "robots" not in self.data:
return

for id, robot_data in enumerate(self.data["robots"]):
for id, robot_data in enumerate(self.data.get("robots", [])):
# Create the robot
robot_name = robot_data.get("name", f"robot{id}")
robot_color = robot_data.get("color", (0.8, 0.0, 0.8))
robot = Robot(
name=robot_name,
radius=robot_data.get("radius", 0.0),
height=robot_data.get("height", 0.0),
color=robot_color,
max_linear_velocity=robot_data.get("max_linear_velocity", np.inf),
max_angular_velocity=robot_data.get("max_angular_velocity", np.inf),
max_linear_acceleration=robot_data.get(
"max_linear_acceleration", np.inf
),
max_angular_acceleration=robot_data.get(
"max_angular_acceleration", np.inf
),
path_planner=self.get_path_planner(robot_data),
path_executor=self.get_path_executor(robot_data),
grasp_generator=self.get_grasp_generator(robot_data),
partial_observability=robot_data.get("partial_observability", False),
action_execution_options=self.get_action_execution_options(robot_data),
initial_battery_level=robot_data.get("initial_battery_level", 100.0),
robot_args = copy.deepcopy(robot_data)
del robot_args["location"]
if "name" not in robot_args:
robot_args["name"] = f"robot{id}"
robot_args["path_planner"] = self.get_path_planner(robot_args)
robot_args["path_executor"] = self.get_path_executor(robot_args)
robot_args["grasp_generator"] = self.get_grasp_generator(robot_args)
robot_args["action_execution_options"] = self.get_action_execution_options(
robot_args
)
robot = Robot(**robot_args)

loc = robot_data["location"] if "location" in robot_data else None
if loc:
loc = robot_args.get("location")
if loc is not None:
loc = self.world.get_entity_by_name(loc)
if "pose" in robot_data:
pose = Pose.construct(robot_data["pose"])
if "pose" in robot_args:
pose = Pose.construct(robot_args["pose"])
else:
pose = None
self.world.add_robot(robot, loc=loc, pose=pose)
Expand All @@ -168,6 +147,7 @@ def get_path_planner(self, robot_data):

planner_class = get_planner_class(planner_type)
path_planner = planner_class(**planner_data)
del robot_data["path_planner"]
return path_planner

def get_path_executor(self, robot_data):
Expand All @@ -177,6 +157,7 @@ def get_path_executor(self, robot_data):

path_executor_data = robot_data["path_executor"].copy()
path_executor_type = path_executor_data["type"]
del robot_data["path_executor"]
del path_executor_data["type"]
if path_executor_type == "constant_velocity":
return ConstantVelocityExecutor(**path_executor_data)
Expand All @@ -198,6 +179,7 @@ def get_grasp_generator(self, robot_data):

grasp_params = robot_data["grasping"].copy()
grasp_gen_type = grasp_params["generator"]
del robot_data["grasping"]
del grasp_params["generator"]
if grasp_gen_type == "parallel_grasp":
grasp_properties = ParallelGraspProperties(**grasp_params)
Expand Down
7 changes: 5 additions & 2 deletions pyrobosim/pyrobosim/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def update_button_state(self):
self.cancel_action_button.setEnabled(False)
self.open_button.setEnabled(True)
self.close_button.setEnabled(True)
self.reset_world_button.setEnabled(has_source_file)
self.reset_world_button.setEnabled(True)
self.reset_path_planner_button.setEnabled(False)
self.rand_pose_button.setEnabled(False)

Expand Down Expand Up @@ -402,7 +402,10 @@ def on_reset_world_click(self):
for robot in self.world.robots:
ros_node.remove_robot_ros_interfaces(robot)

world = WorldYamlLoader().from_yaml(self.world.source_yaml)
if self.world.source_yaml_file is not None:
world = WorldYamlLoader().from_file(self.world.source_yaml_file)
else:
world = WorldYamlLoader().from_yaml(self.world.source_yaml)
self.set_world(world)

# Start up the new robots' ROS interfaces.
Expand Down
9 changes: 5 additions & 4 deletions pyrobosim/pyrobosim/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ def create_logger(name, level=logging.INFO):
logger.setLevel(level)

# TODO: Consider configuring console vs. file logging at some point.
log_formatter = logging.Formatter("[%(name)s] %(levelname)s: %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)
logger.addHandler(console_handler)
if not logger.hasHandlers(): # Prevents adding duplicate handlers
log_formatter = logging.Formatter("[%(name)s] %(levelname)s: %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)
logger.addHandler(console_handler)

# Needed to propagate to unit tests via the caplog fixture.
logger.propagate = True
Expand Down

0 comments on commit 0fe5a7a

Please sign in to comment.