Skip to content

Commit

Permalink
add validity checker in HGS-based local search
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Jan 12, 2025
1 parent 99a05f3 commit d038bff
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 19 deletions.
2 changes: 1 addition & 1 deletion rl4co/envs/routing/cvrp/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor):
# Cannot use less than 0
used_cap[used_cap < 0] = 0
assert (
used_cap <= td["vehicle_capacity"] + 1e-5
used_cap <= td["vehicle_capacity"][:, 0] + 1e-5
).all(), "Used more than capacity"

@staticmethod
Expand Down
77 changes: 59 additions & 18 deletions rl4co/envs/routing/cvrp/local_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_lib_filename(hgs_dir: str) -> str:
raise FileNotFoundError(f"Shared library file `{path}` not found")
return path


# Check if HGS-CVRP is installed
hgs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "HGS-CVRP")
try:
Expand All @@ -46,7 +47,7 @@ def get_lib_filename(hgs_dir: str) -> str:
log.info("HGS-CVRP is installed successfully.")


def local_search(td: TensorDict, actions: torch.Tensor, max_iterations: int = 1000):
def local_search(td: TensorDict, actions: torch.Tensor, max_iterations: int = 1000) -> torch.Tensor:
"""
Improve the solution using local search for CVRP, based on PyVRP.
Expand Down Expand Up @@ -96,8 +97,21 @@ def local_search(td: TensorDict, actions: torch.Tensor, max_iterations: int = 10
# Remove heading and tailing zeros
max_pos = np.max(np.where(new_actions != 0)[1])
new_actions = new_actions[:, 1: max_pos + 1]

return torch.from_numpy(new_actions).to(td.device)
new_actions = torch.from_numpy(new_actions).to(td.device)

# Check the validity of the solution and use the original solution if the new solution is invalid
valid = check_validity(td, new_actions)
import pdb; pdb.set_trace()
if not valid.all():
orig_valid_actions = actions[~valid]
# pad if needed
orig_max_pos = torch.max(torch.where(orig_valid_actions != 0)[1]) + 1
if orig_max_pos > max_pos:
new_actions = torch.nn.functional.pad(
new_actions, (0, orig_max_pos - max_pos, 0, 0), mode="constant", value=0 # type: ignore
)
new_actions[~valid] = orig_valid_actions[:, :orig_max_pos]
return new_actions


def get_subroutes(path, end_with_zero = True) -> List[List[int]]:
Expand All @@ -122,6 +136,41 @@ def merge_subroutes(subroutes, length):
return route


def check_validity(td: TensorDict, actions: torch.Tensor) -> torch.Tensor:
"""
Check the validity of the solution for CVRP and return a boolean tensor.
Modified from CVRPEnv.check_solution_validity in rl4co/envs/routing/cvrp/env.py
"""
# Check if tour is valid, i.e. contain 0 to n-1
batch_size, graph_size = td["demand"].size()
sorted_pi = actions.data.sort(1)[0]

# Sorting it should give all zeros at front and then 1...n
assert (
torch.arange(1, graph_size + 1, out=sorted_pi.data.new())
.view(1, -1)
.expand(batch_size, graph_size)
== sorted_pi[:, -graph_size:]
).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour"

# Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative)
demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1)
d = demand_with_depot.gather(1, actions)

used_cap = torch.zeros_like(td["demand"][:, 0])
valid = torch.ones(batch_size, dtype=torch.bool)
for i in range(actions.size(1)):
used_cap += d[
:, i
] # This will reset/make capacity negative if i == 0, e.g. depot visited
# Cannot use less than 0
used_cap[used_cap < 0] = 0
valid &= (
used_cap <= td["vehicle_capacity"][:, 0] + 1e-5
)
return valid


########################## HGS-CVRP python wrapper ###########################
# Adapted from https://github.com/chkwon/PyHygese/blob/master/hygese/hygese.py

Expand All @@ -132,14 +181,14 @@ def merge_subroutes(subroutes, length):
C_DBL_MAX = sys.float_info.max


def write_routes(routes: List[List[int]], filepath: str):
def write_routes(routes: List[np.ndarray], filepath: str):
with open(filepath, "w") as f:
for i, r in enumerate(routes):
f.write(f"Route #{i + 1}: "+' '.join([str(x) for x in r if x > 0])+"\n")
return


def read_routes(filepath):
def read_routes(filepath) -> List[np.ndarray]:
routes = []
with open(filepath, "r") as f:
while 1:
Expand Down Expand Up @@ -277,7 +326,7 @@ def __init__(self, parameters=AlgorithmParameters(), verbose=False):
self._c_api_delete_sol.restype = None
self._c_api_delete_sol.argtypes = [POINTER(_Solution)]

def local_search(self, data, routes: List[List[int]], count:int = 1,rounding=True,):
def local_search(self, data, routes: List[np.ndarray], count: int = 1) -> List[np.ndarray]:
# required data
demand = np.asarray(data["demands"])
vehicle_capacity = data["vehicle_capacity"]
Expand Down Expand Up @@ -306,8 +355,6 @@ def local_search(self, data, routes: List[List[int]], count:int = 1,rounding=Tru
else:
is_duration_constraint = True

is_rounding_integer = rounding

x_coords = data.get("x_coordinates")
y_coords = data.get("y_coordinates")
dist_mtx = data.get("distance_matrix")
Expand Down Expand Up @@ -351,12 +398,10 @@ def local_search(self, data, routes: List[List[int]], count:int = 1,rounding=Tru
callid,
count,
)

result = read_routes(resultpath)
except Exception as e:
pass
# print(routes)
# print([demand[r].sum() for r in routes])
print(e)
result = routes
else:
os.remove(resultpath)
finally:
Expand Down Expand Up @@ -416,7 +461,7 @@ def _local_search(
return result


def swapstar(demands, matrix, positions, routes, count=1):
def swapstar(demands, matrix, positions, routes: List[np.ndarray], count=1):
ap = AlgorithmParameters()
hgs_solver = Solver(parameters=ap, verbose=False)

Expand All @@ -433,9 +478,5 @@ def swapstar(demands, matrix, positions, routes, count=1):

# Solve with calculated distances
data['distance_matrix'] = matrix
try:
result = hgs_solver.local_search(data, routes, count)
except Exception as e:
print(e)
return routes
result = hgs_solver.local_search(data, routes, count)
return result

0 comments on commit d038bff

Please sign in to comment.