Skip to content

Commit

Permalink
update traffic light controller
Browse files Browse the repository at this point in the history
  • Loading branch information
DhlinV committed Jan 20, 2025
1 parent 2afd69f commit d52876a
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 21 deletions.
36 changes: 36 additions & 0 deletions metadrive/component/traffic_light/base_traffic_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from metadrive.constants import MetaDriveType, Semantics
from metadrive.engine.asset_loader import AssetLoader
from metadrive.utils.pg.utils import generate_static_box_physics_body
from metadrive.engine.logger import get_logger

logger = get_logger()


class BaseTrafficLight(BaseObject):
Expand Down Expand Up @@ -98,6 +101,13 @@ def set_green(self):
self._try_draw_line([3 / 255, 255 / 255, 3 / 255])
self.status = MetaDriveType.LIGHT_GREEN

if self.engine.global_config["use_traffic_light_controller"]:
try:
self.engine.dummy_vehicle.remove(self.dummy_vehicle.id)
self.engine.clear_objects([self.dummy_vehicle.id], force_destroy=True)
except:
pass

def set_red(self):
if self.render:
if self.current_light is not None:
Expand All @@ -107,6 +117,19 @@ def set_red(self):
self._try_draw_line([252 / 255, 0 / 255, 0 / 255])
self.status = MetaDriveType.LIGHT_RED

if self.engine.global_config["use_traffic_light_controller"]:
from metadrive.component.vehicle.vehicle_type import DummyVehicle
logger.info('Traffic light controller is enabled, the throttle_brake would be overwritten by red light')
dummy_vehicle = self.engine.spawn_object(DummyVehicle, vehicle_config={}, position=self.position)
dummy_vehicle.set_static(True)
if not hasattr(self.engine, "dummy_vehicle"):
self.engine.dummy_vehicle = [dummy_vehicle.id]
else:
self.engine.dummy_vehicle.append(dummy_vehicle.id)

# save object
self.dummy_vehicle = dummy_vehicle

def set_yellow(self):
if self.render:
if self.current_light is not None:
Expand All @@ -116,6 +139,13 @@ def set_yellow(self):
self._try_draw_line([252 / 255, 227 / 255, 3 / 255])
self.status = MetaDriveType.LIGHT_YELLOW

if self.engine.global_config["use_traffic_light_controller"]:
try:
self.engine.dummy_vehicle.remove(self.dummy_vehicle.id)
self.engine.clear_objects([self.dummy_vehicle.id], force_destroy=True)
except:
pass

def set_unknown(self):
if self.render:
if self.current_light is not None:
Expand All @@ -130,6 +160,12 @@ def destroy(self):
if self._draw_line:
self._line_drawer.reset()
self._line_drawer.removeNode()
if self.engine.global_config["use_traffic_light_controller"]:
try:
self.engine.dummy_vehicle.remove(self.dummy_vehicle.id)
self.engine.clear_objects([self.dummy_vehicle.id], force_destroy=True)
except:
pass

@property
def top_down_color(self):
Expand Down
33 changes: 32 additions & 1 deletion metadrive/component/vehicle/base_vehicle.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
position=None,
heading=None,
_calling_reset=True,
**kwargs
):
"""
This Vehicle Config is different from self.get_config(), and it is used to define which modules to use, and
Expand All @@ -136,6 +137,8 @@ def __init__(
# self.engine = get_engine()
BaseObject.__init__(self, name, random_seed, self.engine.global_config["vehicle_config"])
BaseVehicleState.__init__(self)
if 'render' in kwargs:
self.render = kwargs['render']
self.update_config(vehicle_config)
self.set_metadrive_type(MetaDriveType.VEHICLE)
use_special_color = self.config["use_special_color"]
Expand Down Expand Up @@ -164,6 +167,12 @@ def __init__(
# navigation module
self.navigation: Optional[NodeNetworkNavigation] = None

# traffic light controller
if 'use_traffic_light_controller' not in self.config:
self.config['use_traffic_light_controller'] = False
self.use_traffic_light_controller = self.config['use_traffic_light_controller']
logger.info(f"Traffic light status controls the status of vehicle: {self.use_traffic_light_controller}")

# state info
self.throttle_brake = 0.0
self.steering = 0
Expand Down Expand Up @@ -352,6 +361,8 @@ def reset(

self.spawn_place = position
# print("position:", position)
if heading is None:
heading = 0.0
self.set_heading_theta(heading)
self.set_static(False)
# self.set_wheel_friction(self.config["wheel_friction"])
Expand Down Expand Up @@ -474,6 +485,20 @@ def _set_action(self, action):
if action is None:
return
steering = action[0]
if self.use_traffic_light_controller:
from metadrive.policy.idm_policy import IDMPolicy
dummy_policy = IDMPolicy(self, np.random.randint(0, 1024))
dummy_policy.enable_lane_change = False
dummy_policy.TIME_WANTED = 0.2
dummy_policy.DISTANCE_WANTED = 2.0
idm_action, acc_front_vehicle = dummy_policy.dummy_act()
if acc_front_vehicle is not None:
if not hasattr(self.engine, "dummy_vehicle"):
self.engine.dummy_vehicle = []
if acc_front_vehicle.id in self.engine.dummy_vehicle:
dummy_pos = acc_front_vehicle.position
action[1] = idm_action[1]
#dself.add_policy(self.id, IDMPolicy, self, self.generate_seed())
self.throttle_brake = action[1]
self.steering = steering
self.system.setSteeringValue(self.steering * self.max_steering, 0)
Expand Down Expand Up @@ -502,7 +527,7 @@ def _apply_throttle_brake(self, throttle_brake):
else:
self.system.applyEngineForce(max_engine_force * throttle_brake, wheel_index)
else:
if self.enable_reverse:
if self.enable_reverse and not self.red_light:
self.system.applyEngineForce(max_engine_force * throttle_brake, wheel_index)
self.system.setBrake(0, wheel_index)
else:
Expand Down Expand Up @@ -773,10 +798,16 @@ def _state_check(self):
light = get_object_from_node(node)
if light.status == MetaDriveType.LIGHT_GREEN:
self.green_light = True
self.yellow_light = False
self.red_light = False
elif light.status == MetaDriveType.LIGHT_RED:
self.red_light = True
self.green_light = False
self.yellow_light = False
elif light.status == MetaDriveType.LIGHT_YELLOW:
self.yellow_light = True
self.green_light = False
self.red_light = False
elif light.status == MetaDriveType.LIGHT_UNKNOWN:
# unknown didn't add
continue
Expand Down
42 changes: 40 additions & 2 deletions metadrive/component/vehicle/vehicle_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from metadrive.engine.asset_loader import AssetLoader


# When using DefaultVehicle as traffic, please use this class.
class DefaultVehicle(BaseVehicle):
PARAMETER_SPACE = ParameterSpace(VehicleParameterSpace.DEFAULT_VEHICLE)
# LENGTH = 4.51
Expand Down Expand Up @@ -39,7 +40,43 @@ def WIDTH(self):
return self.DEFAULT_WIDTH


# When using DefaultVehicle as traffic, please use this class.
class DummyVehicle(BaseVehicle):
PARAMETER_SPACE = ParameterSpace(VehicleParameterSpace.DEFAULT_VEHICLE)
# LENGTH = 4.51
# WIDTH = 1.852
# HEIGHT = 1.19
TIRE_RADIUS = 0.313
TIRE_WIDTH = 0.25
MASS = 1100
LATERAL_TIRE_TO_CENTER = 0.815
FRONT_WHEELBASE = 1.05234
REAR_WHEELBASE = 1.4166
path = None #('ferra/vehicle.gltf', (0.1, 0.1, 1), (0, 0.075, 0.), (0, 0, 0)) # asset path, scale, offset, HPR

DEFAULT_LENGTH = 0.515 # meters
DEFAULT_HEIGHT = 1.19 # meters
DEFAULT_WIDTH = 1.852 # meters

render = False

def __init__(
self, vehicle_config=None, name=None, random_seed=None, position=None, heading=None, _calling_reset=True
):
super(DummyVehicle, self).__init__(
vehicle_config, name, random_seed, position, heading, _calling_reset, render=False
)

@property
def LENGTH(self):
return self.DEFAULT_LENGTH

@property
def HEIGHT(self):
return self.DEFAULT_HEIGHT

@property
def WIDTH(self):
return self.DEFAULT_WIDTH


class TrafficDefaultVehicle(DefaultVehicle):
Expand Down Expand Up @@ -442,7 +479,8 @@ def random_vehicle_type(np_random, p=None):
"static_default": StaticDefaultVehicle,
"varying_dynamics": VaryingDynamicsVehicle,
"varying_dynamics_bounding_box": VaryingDynamicsBoundingBoxVehicle,
"traffic_default": TrafficDefaultVehicle
"traffic_default": TrafficDefaultVehicle,
'dummy': DummyVehicle
}

vehicle_class_to_type = inv_map = {v: k for k, v in vehicle_type.items()}
7 changes: 6 additions & 1 deletion metadrive/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@
show_lane_line_detector=False,
# Whether to turn on vehicle light, only available when enabling render-pipeline
light=False,
# Whether to use traffic light controller for the vehicle
use_traffic_light_controller=False,
),

# ===== Sensors =====
Expand Down Expand Up @@ -272,7 +274,10 @@
force_reuse_object_name=False,

# ===== randomization =====
num_scenarios=1 # the number of scenarios in this environment
num_scenarios=1, # the number of scenarios in this environment

# ===== Traffic Light =====
use_traffic_light_controller=False,
)


Expand Down
34 changes: 34 additions & 0 deletions metadrive/policy/idm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,40 @@ def __init__(self, control_object, random_seed):
self.heading_pid = PIDController(1.7, 0.01, 3.5)
self.lateral_pid = PIDController(0.3, .002, 0.05)

def dummy_act(self, *args, **kwargs):
# concat lane
success = self.move_to_next_road()
all_objects = self.control_object.lidar.get_surrounding_objects(self.control_object)
try:
if success and self.enable_lane_change:
# perform lane change due to routing
acc_front_obj, acc_front_dist, steering_target_lane = self.lane_change_policy(all_objects)
else:
# can not find routing target lane
surrounding_objects = FrontBackObjects.get_find_front_back_objs(
all_objects,
self.routing_target_lane,
self.control_object.position,
max_distance=self.MAX_LONG_DIST
)
acc_front_obj = surrounding_objects.front_object()
acc_front_dist = surrounding_objects.front_min_distance()
steering_target_lane = self.routing_target_lane
except:
# error fallback
acc_front_obj = None
acc_front_dist = 5
steering_target_lane = self.routing_target_lane
# logging.warning("IDM bug! fall back")
# print("IDM bug! fall back")

# control by PID and IDM
steering = self.steering_control(steering_target_lane)
acc = self.acceleration(acc_front_obj, acc_front_dist)
action = [steering, acc]
self.action_info["action"] = action
return action, acc_front_obj

def act(self, *args, **kwargs):
# concat lane
success = self.move_to_next_road()
Expand Down
50 changes: 33 additions & 17 deletions metadrive/tests/test_component/test_traffic_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ def test_traffic_light(render=False, manual_control=False, debug=False):
"window_size": (1200, 800),
"vehicle_config": {
"enable_reverse": True,
"show_dest_mark": True
"show_dest_mark": True,
'use_traffic_light_controller': True
},
'use_traffic_light_controller': True
}
)
env.reset()
if debug:
env.engine.toggleDebug()
try:
# green
env.reset()
Expand All @@ -31,38 +35,50 @@ def test_traffic_light(render=False, manual_control=False, debug=False):
for s in range(1, 100):
env.step([0, 1])
if env.agent.green_light:
print('[Successfully detected] Green light')
test_success = True
break
print('Go through the green light')
# break
assert test_success
light.destroy()

# red test
env.reset()
light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0])
# dummy_vehicle = env.engine.spawn_object(DummyVehicle, vehicle_config={}, position=light.position)
# dummy_vehicle.set_static(True)
# if not hasattr(env.engine, "dummy_vehicle"):
# env.engine.dummy_vehicle = [dummy_vehicle.id]
# else:
# env.engine.dummy_vehicle.append(dummy_vehicle.id)
light.set_red()
test_success = False
for s in range(1, 100):
env.step([0, 1])
if env.agent.red_light:
test_success = True
break
assert test_success
light.destroy()
# yellow
env.reset()
light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0])
light.set_yellow()
test_success = False
light.set_green()
for s in range(1, 100):
env.step([0, 1])
if env.agent.yellow_light:
test_success = True
break
assert test_success
light.destroy()
# env.engine.dummy_vehicle.remove(dummy_vehicle.id)
# env.engine.clear_objects([dummy_vehicle.id], force_destroy=True)

# yellow
# env.reset()
# light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0])
# light.set_yellow()
# test_success = False
# for s in range(1, 100):
# env.step([0, 1])
# if env.agent.yellow_light:
# print('[Successfully detected] Yellow light')
# test_success = True
# # break
# assert test_success
# light.destroy()

finally:
env.close()


if __name__ == "__main__":
test_traffic_light(True, manual_control=True)
test_traffic_light(render=True, manual_control=False, debug=False)

0 comments on commit d52876a

Please sign in to comment.