
Add static data and model for hybrid link gnn #27

Open · wants to merge 13 commits into master
577 changes: 577 additions & 0 deletions examples/static_example.py

Large diffs are not rendered by default.

32 changes: 26 additions & 6 deletions hybridgnn/nn/models/hybridgnn.py
@@ -35,9 +35,20 @@ def __init__(
norm: str = 'layer_norm',
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
is_static: Optional[bool] = False,
num_src_nodes: Optional[int] = None,
src_entity_table: Optional[str] = None,
) -> None:
super().__init__(data, col_stats_dict, rhs_emb_mode, dst_entity_table,
num_nodes, embedding_dim)
super().__init__(
data,
col_stats_dict,
rhs_emb_mode,
dst_entity_table,
num_nodes,
embedding_dim,
num_src_nodes,
src_entity_table,
)

self.encoder = HeteroEncoder(
channels=channels,
@@ -76,6 +87,7 @@ def __init__(
self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
self.channels = channels
self.is_static = is_static

self.reset_parameters()

@@ -90,6 +102,8 @@ def reset_parameters(self) -> None:
self.lin_offset_embgnn.reset_parameters()
self.lin_offset_idgnn.reset_parameters()
self.lhs_projector.reset_parameters()
if self.lhs_embedding is not None:
self.lhs_embedding.reset_parameters()

def forward(
self,
@@ -100,14 +114,20 @@ def forward(
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)

if self.lhs_embedding is not None:
lhs_embedding = self.lhs_embedding()[batch[entity_table].n_id]
x_dict[entity_table] = lhs_embedding

# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time
if not self.is_static:
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
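Not part of the diff: a minimal construction sketch of how the new `is_static`, `num_src_nodes`, and `src_entity_table` arguments could be passed to `HybridGNN` for a static (timestamp-free) graph. The table names, sizes, and the `data`/`col_stats_dict` objects are placeholders, and the remaining keyword names are assumed to match the constructor shown above.

```python
# Illustrative only: objects and table names below are placeholders.
model = HybridGNN(
    data=data,                      # HeteroData built without timestamps
    col_stats_dict=col_stats_dict,
    rhs_emb_mode=rhs_emb_mode,
    dst_entity_table="item",        # hypothetical destination table
    num_nodes=num_dst_nodes,
    embedding_dim=64,
    channels=128,
    is_static=True,                 # skip the temporal encoder in forward()
    num_src_nodes=num_src_nodes,    # enables the optional LHS shallow embedding
    src_entity_table="user",        # hypothetical source table
)
```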
11 changes: 7 additions & 4 deletions hybridgnn/nn/models/idgnn.py
@@ -29,6 +29,7 @@ def __init__(
norm: str = 'layer_norm',
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
is_static=True,
) -> None:
super().__init__()

@@ -65,6 +66,7 @@ def __init__(
)

self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.is_static = is_static
self.reset_parameters()

def reset_parameters(self) -> None:
@@ -85,11 +87,12 @@ def forward(
# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)
if not self.is_static:
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time
for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
32 changes: 31 additions & 1 deletion hybridgnn/nn/models/rhsembeddinggnn.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch
from torch_frame.data.stats import StatType
@@ -19,6 +19,8 @@ def __init__(
dst_entity_table: str,
num_nodes: int,
embedding_dim: int,
num_src_nodes: Optional[int] = None,
src_entity_table: Optional[str] = None,
):
super().__init__()
stype_encoder_dict = {
@@ -36,19 +38,47 @@ def __init__(
feat=data[dst_entity_table]['tf'],
)

self.lhs_embedding = None
if num_src_nodes is not None:
assert src_entity_table is not None

src_stype_encoder_dict = {
k: v[0]()
for k, v in DEFAULT_STYPE_ENCODER_DICT.items()
if k in data[src_entity_table]['tf'].col_names_dict.keys()
}

self.lhs_embedding = RHSEmbedding(
emb_mode=rhs_emb_mode,
embedding_dim=embedding_dim,
num_nodes=num_src_nodes,
col_stats=col_stats_dict[src_entity_table],
col_names_dict=data[src_entity_table]['tf'].col_names_dict,
stype_encoder_dict=src_stype_encoder_dict,
feat=data[src_entity_table]['tf'],
)

def reset_parameters(self):
self.rhs_embedding.reset_parameters()
if self.lhs_embedding is not None:
self.lhs_embedding.reset_parameters()

def to(self, *args, **kwargs) -> Self:
# Explicitly call `to` on the RHS embedding to move caches to the
# device.
self.rhs_embedding.to(*args, **kwargs)
if self.lhs_embedding is not None:
self.lhs_embedding.to(*args, **kwargs)
return super().to(*args, **kwargs)

def cpu(self) -> Self:
self.rhs_embedding.cpu()
if self.lhs_embedding is not None:
self.lhs_embedding.cpu()
return super().cpu()

def cuda(self, *args, **kwargs) -> Self:
self.rhs_embedding.cuda(*args, **kwargs)
if self.lhs_embedding is not None:
self.lhs_embedding.cuda(*args, **kwargs)
return super().cuda(*args, **kwargs)
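For context, not part of the diff: the new `lhs_embedding` mirrors the existing `rhs_embedding` and is looked up per mini-batch by node index, as the hybridgnn.py hunk above shows. A minimal sketch of that lookup, assuming (as that hunk suggests) that calling the embedding with no arguments returns the full table and that `batch[entity_table].n_id` holds the global indices of the sampled source nodes:

```python
# Sketch of the caller-side lookup; mirrors the lines added to HybridGNN.forward().
if model.lhs_embedding is not None:
    lhs = model.lhs_embedding()             # [num_src_nodes, embedding_dim]
    x_src = lhs[batch[entity_table].n_id]   # rows for the sampled source nodes
```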
13 changes: 9 additions & 4 deletions hybridgnn/nn/models/shallowrhsgnn.py
@@ -35,6 +35,7 @@ def __init__(
norm: str = 'layer_norm',
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
is_static: Optional[bool] = False,
) -> None:
super().__init__(data, col_stats_dict, rhs_emb_mode, dst_entity_table,
num_nodes, embedding_dim)
@@ -71,6 +72,8 @@ def __init__(
)
self.lhs_projector = torch.nn.Linear(channels, embedding_dim)
self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.is_static = is_static

self.reset_parameters()

def reset_parameters(self) -> None:
@@ -94,11 +97,13 @@ def forward(
# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time
if not self.is_static:
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
1 change: 1 addition & 0 deletions static_data/amazon-book/README.md
@@ -0,0 +1 @@
Looking for the full dataset? Please visit the [website](http://jmcauley.ucsd.edu/data/amazon).