Commit

fix bug in CVRPEdgeEmbedding and some minor refactorings

hyeok9855 committed Jan 12, 2025
1 parent d038bff commit 8270199
Showing 8 changed files with 115 additions and 30 deletions.
9 changes: 6 additions & 3 deletions configs/experiment/routing/deepaco.yaml
@@ -25,6 +25,7 @@ model:
train_data_size: 400
val_data_size: 20
test_data_size: 100
optimizer: "AdamW"
optimizer_kwargs:
lr: 1e-3
weight_decay: 0
@@ -45,8 +46,8 @@ model:
n_iterations:
train: 1 # unused value
val: 5
test: 20
temperature: 0.5
test: 10
temperature: 1.0
top_p: 0.0
top_k: 0
aco_kwargs:
@@ -55,12 +56,14 @@ model:
decay: 0.95
use_local_search: True
use_nls: True
n_perturbations: 5
local_search_params:
max_iterations: 1000
perturbation_params:
max_iterations: 20
k_sparse: 5 # this should be adjusted based on the `num_loc` value

trainer:
max_epochs: 50

seed: 1234
seed: 1234
12 changes: 8 additions & 4 deletions configs/experiment/routing/gfacs.yaml
@@ -25,6 +25,7 @@ model:
train_data_size: 400
val_data_size: 20
test_data_size: 100
optimizer: "AdamW"
optimizer_kwargs:
lr: 1e-3
weight_decay: 0
@@ -45,8 +46,8 @@ model:
n_iterations:
train: 1 # unused value
val: 5
test: 20
temperature: 0.5
test: 10
temperature: 1.0
top_p: 0.0
top_k: 0
aco_kwargs:
@@ -55,15 +56,18 @@ model:
decay: 0.95
use_local_search: True
use_nls: True
n_perturbations: 5
local_search_params:
max_iterations: 1000
perturbation_params:
max_iterations: 20
k_sparse: 5 # this should be adjusted based on the `num_loc` value

beta: 1000
beta_min: 100
beta_max: 500
beta_flat_epochs: 5

trainer:
max_epochs: 50

seed: 1234
seed: 1234
10 changes: 5 additions & 5 deletions rl4co/envs/routing/cvrp/local_search.py
@@ -100,17 +100,17 @@ def local_search(td: TensorDict, actions: torch.Tensor, max_iterations: int = 10
new_actions = torch.from_numpy(new_actions).to(td.device)

# Check the validity of the solution and use the original solution if the new solution is invalid
valid = check_validity(td, new_actions)
import pdb; pdb.set_trace()
if not valid.all():
orig_valid_actions = actions[~valid]
isvalid = check_validity(td, new_actions)
if not isvalid.all():
new_actions[~isvalid] = 0
orig_valid_actions = actions[~isvalid]
# pad if needed
orig_max_pos = torch.max(torch.where(orig_valid_actions != 0)[1]) + 1
if orig_max_pos > max_pos:
new_actions = torch.nn.functional.pad(
new_actions, (0, orig_max_pos - max_pos, 0, 0), mode="constant", value=0 # type: ignore
)
new_actions[~valid] = orig_valid_actions[:, :orig_max_pos]
new_actions[~isvalid, :orig_max_pos] = orig_valid_actions[:, :orig_max_pos]
return new_actions


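The fix in `local_search` above replaces a whole-row assignment with an explicitly sliced one and zeroes the rolled-back rows first. A minimal toy sketch (tensors and sizes invented for illustration) of why the old assignment could raise a shape mismatch and what the fixed indexing does:

```python
import torch

# Two candidate solutions after local search; row 1 is "invalid" and must be
# rolled back to its original action sequence.
new_actions = torch.tensor([[1, 2, 3, 9, 9],
                            [4, 5, 6, 7, 8]])   # repaired solutions, width max_pos = 5
actions = torch.tensor([[1, 2, 3, 0, 0],
                        [4, 5, 0, 0, 0]])       # original solutions
isvalid = torch.tensor([True, False])

orig_valid_actions = actions[~isvalid]                                  # shape [1, 5]
orig_max_pos = torch.max(torch.where(orig_valid_actions != 0)[1]) + 1   # == 2

# Old code: the selected rows have width 5 but the right-hand side has width 2.
# new_actions[~isvalid] = orig_valid_actions[:, :orig_max_pos]  # RuntimeError: shape mismatch

# Fixed code: zero the invalid rows, then write only the first orig_max_pos columns.
new_actions[~isvalid] = 0
new_actions[~isvalid, :orig_max_pos] = orig_valid_actions[:, :orig_max_pos]
print(new_actions)
# tensor([[1, 2, 3, 9, 9],
#         [4, 5, 0, 0, 0]])
```

The padding branch in the real function handles the opposite case, where the original solutions are longer than the repaired ones; the sliced assignment works for both.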
72 changes: 65 additions & 7 deletions rl4co/models/nn/env_embeddings/edge.py
@@ -26,11 +26,11 @@ def env_edge_embedding(env_name: str, config: dict) -> nn.Module:
embedding_registry = {
"tsp": TSPEdgeEmbedding,
"atsp": ATSPEdgeEmbedding,
"cvrp": TSPEdgeEmbedding,
"cvrp": CVRPEdgeEmbedding,
"sdvrp": TSPEdgeEmbedding,
"pctsp": TSPEdgeEmbedding,
"pctsp": CVRPEdgeEmbedding,
"spctsp": TSPEdgeEmbedding,
"op": TSPEdgeEmbedding,
"op": CVRPEdgeEmbedding,
"dpp": TSPEdgeEmbedding,
"mdpp": TSPEdgeEmbedding,
"pdp": TSPEdgeEmbedding,
@@ -93,12 +93,70 @@ def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tens
edge_index = get_full_graph_edge_index(
cost_matrix.shape[0], self_loop=False
).to(cost_matrix.device)
edge_attr = cost_matrix[edge_index[0], edge_index[1]]
edge_attr = cost_matrix[edge_index[0], edge_index[1]].unsqueeze(-1)

graph = Data(
x=init_embeddings[index],
edge_index=edge_index,
edge_attr=edge_attr,
x=init_embeddings[index], edge_index=edge_index, edge_attr=edge_attr
)
graph_data.append(graph)

batch = Batch.from_data_list(graph_data)
batch.edge_attr = self.edge_embed(batch.edge_attr)
return batch

class CVRPEdgeEmbedding(TSPEdgeEmbedding):
"""Edge embedding module for the Capacitated Vehicle Routing Problem (CVRP).
Unlike the TSP, all nodes in the CVRP should be connected to the depot,
so each node will have k_sparse + 1 edges.
"""

def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tensor):
"""Convert batched cost_matrix to batched PyG graph, and calculate edge embeddings.
Args:
batch_cost_matrix: Tensor of shape [batch_size, n, n]
init_embeddings: initial node embeddings
"""
graph_data = []
for index, cost_matrix in enumerate(batch_cost_matrix):
if self.sparsify:
edge_index, edge_attr = sparsify_graph(
cost_matrix[1:, 1:], self.k_sparse, self_loop=False
)
edge_index = edge_index + 1 # because we removed the depot
# Note: connect the depot (node 0) to every other node in both directions
edge_index = torch.cat(
[
edge_index,
# All nodes should be connected to the depot
torch.stack(
[
torch.arange(1, cost_matrix.shape[0]),
torch.zeros(cost_matrix.shape[0] - 1, dtype=torch.long),
]
).to(edge_index.device),
# Depot should be connected to all nodes
torch.stack(
[
torch.zeros(cost_matrix.shape[0] - 1, dtype=torch.long),
torch.arange(1, cost_matrix.shape[0]),
]
).to(edge_index.device),
],
dim=1,
)
edge_attr = torch.cat(
[edge_attr, cost_matrix[1:, [0]], cost_matrix[[0], 1:].t()], dim=0
)

else:
edge_index = get_full_graph_edge_index(
cost_matrix.shape[0], self_loop=False
).to(cost_matrix.device)
edge_attr = cost_matrix[edge_index[0], edge_index[1]].unsqueeze(-1)

graph = Data(
x=init_embeddings[index], edge_index=edge_index, edge_attr=edge_attr
)
graph_data.append(graph)

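The new `CVRPEdgeEmbedding` above sparsifies only the customer-to-customer block of the cost matrix and then appends bidirectional depot edges, so every customer keeps `k_sparse` neighbor edges plus an edge to and from the depot. A standalone sketch of that edge-index construction in plain PyTorch (the real code uses rl4co's `sparsify_graph`; the top-k step below is an assumption about what it does):

```python
import torch

n, k_sparse = 6, 2                      # 1 depot + 5 customers, keep 2 nearest neighbors
cost_matrix = torch.rand(n, n)

# Sparsify only the customer block, as in the diff (cost_matrix[1:, 1:]).
customer_costs = cost_matrix[1:, 1:].clone()
customer_costs.fill_diagonal_(float("inf"))                  # no self-loops
nbr = customer_costs.topk(k_sparse, dim=1, largest=False).indices
src = torch.arange(n - 1).repeat_interleave(k_sparse)
edge_index = torch.stack([src, nbr.reshape(-1)]) + 1         # shift by 1: node 0 is the depot

# Append depot <-> customer edges in both directions, as the new class does.
customers = torch.arange(1, n)
depot = torch.zeros(n - 1, dtype=torch.long)
edge_index = torch.cat(
    [edge_index, torch.stack([customers, depot]), torch.stack([depot, customers])],
    dim=1,
)
edge_attr = cost_matrix[edge_index[0], edge_index[1]].unsqueeze(-1)

print(edge_index.shape)   # torch.Size([2, 20]): (n - 1) * k_sparse + 2 * (n - 1) = 20 edges
```

Previously `cvrp` (and `pctsp`, `op`) fell back to `TSPEdgeEmbedding`, which treats the depot like any other node during sparsification, so most customers end up without a direct edge to it; the registry change routes those environments to the new class.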
4 changes: 1 addition & 3 deletions rl4co/models/zoo/deepaco/policy.py
@@ -49,7 +49,7 @@ def __init__(
**encoder_kwargs,
):
if encoder is None:
encoder = NARGNNEncoder(**encoder_kwargs)
encoder = NARGNNEncoder(env_name=env_name, **encoder_kwargs)

super(DeepACOPolicy, self).__init__(
encoder=encoder,
@@ -141,8 +141,6 @@ def forward(
assert self.top_p <= 1.0, "top-p should be in (0, 1]."
heatmap_logits = modify_logits_for_top_p_filtering(heatmap_logits, self.top_p)

heatmap_logits = torch.nan_to_num(heatmap_logits, nan=math.log(1e-10), neginf=math.log(1e-10))

aco = self.aco_class(heatmap_logits, n_ants=n_ants, **self.aco_kwargs)
td, actions, reward = aco.run(td_initial, env, self.n_iterations[phase])

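This change also drops the `torch.nan_to_num` clamp on the heatmap logits (the same line is removed from the GFACS policy below). The clamp appears redundant because the NARGNN encoder adds a small positive constant to the heatmap before taking the log (see the `nargnn/encoder.py` diff), so the logits are already finite, assuming the heatmap itself contains no NaNs. A toy before/after comparison:

```python
import math
import torch

heatmap = torch.tensor([0.0, 0.3, 0.7])   # toy heatmap with a zero entry

# Without any epsilon, log(0) = -inf; the removed line clamped that to log(1e-10).
old_logits = torch.nan_to_num(torch.log(heatmap), nan=math.log(1e-10), neginf=math.log(1e-10))

# With the encoder's epsilon added before the log, the logits are finite to begin
# with, so the clamp is a no-op and can be dropped.
new_logits = torch.log(heatmap + 1e-12)

print(old_logits)   # tensor([-23.0259, -1.2040, -0.3567])
print(new_logits)   # tensor([-27.6310, -1.2040, -0.3567])
```

The other change here, passing `env_name` to `NARGNNEncoder`, presumably lets the encoder look up the environment-specific edge embedding (e.g. `CVRPEdgeEmbedding`) from the registry above instead of always falling back to the TSP one.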
23 changes: 20 additions & 3 deletions rl4co/models/zoo/gfacs/model.py
@@ -36,7 +36,9 @@ def __init__(
ls_reward_aug_W: float = 0.95,
policy_kwargs: dict = {},
baseline_kwargs: dict = {},
beta: float = 1.0,
beta_min: float = 1.0,
beta_max: float = 1.0,
beta_flat_epochs: int = 0,
**kwargs,
):
if policy is None:
@@ -45,10 +47,25 @@
)

super().__init__(
env, policy, baseline, train_with_local_search, ls_reward_aug_W, policy_kwargs, baseline_kwargs, **kwargs
env,
policy,
baseline,
train_with_local_search,
ls_reward_aug_W,
policy_kwargs,
baseline_kwargs,
**kwargs,
)

self.beta = beta
self.beta_min = beta_min
self.beta_max = beta_max
self.beta_flat_epochs = beta_flat_epochs

@property
def beta(self) -> float:
return self.beta_min + (self.beta_max - self.beta_min) * min(
math.log(self.current_epoch + 1) / math.log(self.trainer.max_epochs - self.beta_flat_epochs), 1.0
)

def calculate_loss(
self,
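The fixed `beta` attribute is replaced by a property that anneals logarithmically from `beta_min` to `beta_max`, reaching the maximum `beta_flat_epochs` epochs before the end of training and staying flat afterwards. A standalone restatement of the schedule (the real property reads `current_epoch` and `trainer.max_epochs` from Lightning), using the values from the updated gfacs.yaml (`beta_min=100`, `beta_max=500`, `beta_flat_epochs=5`, `max_epochs=50`):

```python
import math

def beta_schedule(epoch: int, beta_min=100.0, beta_max=500.0,
                  max_epochs=50, beta_flat_epochs=5) -> float:
    # Log-anneal from beta_min to beta_max; the min(..., 1.0) keeps beta at
    # beta_max for the final beta_flat_epochs epochs.
    progress = math.log(epoch + 1) / math.log(max_epochs - beta_flat_epochs)
    return beta_min + (beta_max - beta_min) * min(progress, 1.0)

for epoch in (0, 10, 25, 44, 49):
    print(epoch, round(beta_schedule(epoch), 1))
# 0 100.0
# 10 352.0
# 25 442.4
# 44 500.0
# 49 500.0
```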
4 changes: 1 addition & 3 deletions rl4co/models/zoo/gfacs/policy.py
@@ -56,7 +56,7 @@ def __init__(
):
if encoder is None:
encoder_kwargs["z_out_dim"] = 2 if train_with_local_search else 1
encoder = GFACSEncoder(**encoder_kwargs)
encoder = GFACSEncoder(env_name=env_name, **encoder_kwargs)

super().__init__(
encoder=encoder,
@@ -169,8 +169,6 @@ def forward(
if self.top_p > 0:
assert self.top_p <= 1.0, "top-p should be in (0, 1]."
heatmap_logits = modify_logits_for_top_p_filtering(heatmap_logits, self.top_p)

heatmap_logits = torch.nan_to_num(heatmap_logits, nan=math.log(1e-10), neginf=math.log(1e-10))

aco = self.aco_class(heatmap_logits, n_ants=n_ants, **self.aco_kwargs)
td, actions, reward = aco.run(td_initial, env, self.n_iterations[phase])
11 changes: 9 additions & 2 deletions rl4co/models/zoo/nargnn/encoder.py
@@ -80,8 +80,15 @@ def _make_heatmap_logits(self, batch_graph: Batch) -> Tensor: # type: ignore
# if self.undirected_graph:
# heatmap = (heatmap + heatmap.transpose(1, 2)) * 0.5

heatmap += 1e-10 if heatmap.dtype != torch.float16 else 3e-8
# 3e-8 is the smallest positive number such that log(3e-8) is not -inf
# Avoid log(0) by adding a small value
if heatmap.dtype == torch.float32 or heatmap.dtype == torch.bfloat16:
small_value = 1e-12
elif heatmap.dtype == torch.float16:
small_value = 3e-8 # the smallest positive number such that log(small_value) is not -inf
else:
raise ValueError(f"Unsupported dtype: {heatmap.dtype}")

heatmap += small_value
heatmap_logits = torch.log(heatmap)

return heatmap_logits
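The dtype branch matters because very small constants underflow to exactly zero in float16, which would reintroduce `-inf` into the logits after the log; 3e-8 rounds up to the smallest positive float16 value (about 5.96e-8), so the result stays finite, while float32 and bfloat16 share the wider float32 exponent range and can represent 1e-12 directly. A quick check in plain PyTorch:

```python
import torch

# float16: tiny epsilons underflow to zero, so log(0 + eps) is still -inf.
print(torch.tensor(1e-12, dtype=torch.float16))               # tensor(0., dtype=torch.float16)
print(torch.log(torch.tensor(1e-12, dtype=torch.float16)))    # tensor(-inf, dtype=torch.float16)

# 3e-8 rounds to the smallest positive float16 value (~5.96e-8): the log stays finite.
print(torch.log(torch.tensor(3e-8, dtype=torch.float16)))     # ~ -16.64

# bfloat16 keeps the float32 exponent range, so 1e-12 is representable and safe.
print(torch.log(torch.tensor(1e-12, dtype=torch.bfloat16)))   # ~ -27.6
```

The explicit `ValueError` makes unsupported dtypes fail loudly instead of silently falling back to 1e-10, as the old one-liner did.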
