Commit
Showing 21 changed files with 677 additions and 9 deletions.
Empty file.
@@ -0,0 +1,3 @@
## ogbg-molpcba classification

Based on the [Flax ogbg-molpcba example](https://github.com/google/flax/tree/main/examples/ogbg_molpcba).
Empty file.
161 changes: 161 additions & 0 deletions
algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py
@@ -0,0 +1,161 @@
# Forked from Flax example which can be found here:
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/input_pipeline.py
# and from the init2winit fork here
# https://github.com/google/init2winit/blob/master/init2winit/dataset_lib/ogbg_molpcba.py
"""Exposes the ogbg-molpcba dataset in a convenient format."""

import jax
import jraph
import numpy as np
import tensorflow_datasets as tfds

AVG_NODES_PER_GRAPH = 26
AVG_EDGES_PER_GRAPH = 56


def _load_dataset(split, should_shuffle, data_rng, data_dir):
  """Loads a dataset split from TFDS."""
  if should_shuffle:
    file_data_rng, dataset_data_rng = jax.random.split(data_rng)
    file_data_rng = file_data_rng[0]
    dataset_data_rng = dataset_data_rng[0]
  else:
    file_data_rng = None
    dataset_data_rng = None

  read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng)
  dataset = tfds.load(
      'ogbg_molpcba',
      split=split,
      shuffle_files=should_shuffle,
      read_config=read_config,
      data_dir=data_dir)

  if should_shuffle:
    dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15)
    dataset = dataset.repeat()

  return dataset


def _to_jraph(example):
  """Converts an example graph to jraph.GraphsTuple."""
  example = jax.tree_map(lambda x: x._numpy(), example)  # pylint: disable=protected-access
  edge_feat = example['edge_feat']
  node_feat = example['node_feat']
  edge_index = example['edge_index']
  labels = example['labels']
  num_nodes = example['num_nodes']

  senders = edge_index[:, 0]
  receivers = edge_index[:, 1]

  return jraph.GraphsTuple(
      n_node=num_nodes,
      n_edge=np.array([len(edge_index) * 2]),
      nodes=node_feat,
      edges=np.concatenate([edge_feat, edge_feat]),
      # Make the edges bidirectional
      senders=np.concatenate([senders, receivers]),
      receivers=np.concatenate([receivers, senders]),
      # Keep the labels with the graph for batching. They will be removed
      # in the processed batch.
      globals=np.expand_dims(labels, axis=0))


def _get_weights_by_nan_and_padding(labels, padding_mask):
  """Handles NaNs and padding in labels.

  Sets all the weights from examples coming from padding to 0. Changes all NaNs
  in labels to 0s and sets the corresponding per-label weight to 0.

  Args:
    labels: Labels including labels from padded examples
    padding_mask: Binary array of which examples are padding

  Returns:
    tuple of (processed labels, corresponding weights)
  """
  nan_mask = np.isnan(labels)
  replaced_labels = np.copy(labels)
  np.place(replaced_labels, nan_mask, 0)

  weights = 1.0 - nan_mask
  # Weights for all labels of a padded element will be 0
  weights = weights * padding_mask[:, None]
  return replaced_labels, weights


def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
  """Turns a per-example iterator into a batched iterator.

  Constructs the batch from num_shards smaller batches, so that we can easily
  shard the batch to multiple devices during training. We use
  dynamic batching, so we specify some max number of graphs/nodes/edges, add
  as many graphs as we can, and then pad to the max values.

  Args:
    dataset_iter: The TFDS dataset iterator.
    global_batch_size: How many average-sized graphs go into the batch.
    num_shards: How many devices we should be able to shard the batch into.

  Yields:
    Batch in the init2winit format. Each field is a list of num_shards separate
    smaller batches.
  """
  if not num_shards:
    num_shards = jax.device_count()

  # We will construct num_shards smaller batches and then put them together.
  per_device_batch_size = global_batch_size / num_shards

  max_n_nodes = AVG_NODES_PER_GRAPH * per_device_batch_size
  max_n_edges = AVG_EDGES_PER_GRAPH * per_device_batch_size
  max_n_graphs = per_device_batch_size

  jraph_iter = map(_to_jraph, dataset_iter)
  batched_iter = jraph.dynamically_batch(jraph_iter,
                                         max_n_nodes + 1,
                                         max_n_edges,
                                         max_n_graphs + 1)

  count = 0
  graphs_shards = []
  labels_shards = []
  weights_shards = []

  for batched_graph in batched_iter:
    count += 1

    # Separate the labels from the graph
    labels = batched_graph.globals
    graph = batched_graph._replace(globals={})

    replaced_labels, weights = _get_weights_by_nan_and_padding(
        labels, jraph.get_graph_padding_mask(graph))

    graphs_shards.append(graph)
    labels_shards.append(replaced_labels)
    weights_shards.append(weights)

    if count == num_shards:

      def f(x):
        return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])

      graphs_shards = f(graphs_shards)
      labels_shards = f(labels_shards)
      weights_shards = f(weights_shards)
      yield (graphs_shards, labels_shards, weights_shards)

      count = 0
      graphs_shards = []
      labels_shards = []
      weights_shards = []


def get_dataset_iter(split, data_rng, data_dir, global_batch_size):
  ds = _load_dataset(
      split,
      should_shuffle=(split == 'train'),
      data_rng=data_rng,
      data_dir=data_dir)
  return _get_batch_iterator(iter(ds), global_batch_size)
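For orientation, a minimal sketch of how this pipeline might be driven. The import path follows the file location above, but the data directory, batch size, and the assumption that the ogbg_molpcba TFDS dataset is already downloaded are illustrative, not part of this commit:

# Hypothetical driver for the input pipeline above; paths and sizes are placeholders.
import jax
from algorithmic_efficiency.workloads.ogbg.ogbg_jax import input_pipeline

data_rng = jax.random.PRNGKey(0)
train_iter = input_pipeline.get_dataset_iter(
    'train',
    data_rng=data_rng,
    data_dir='~/tensorflow_datasets',  # assumed TFDS location, not from the commit
    global_batch_size=256)

# Each element is (graphs, labels, weights); every field is stacked along a
# leading axis of length jax.device_count() so it can be fed to jax.pmap.
graphs, labels, weights = next(train_iter)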
@@ -0,0 +1,58 @@
# Forked from Flax example which can be found here:
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py

from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.metrics import average_precision_score


def predictions_match_labels(*,
                             logits: jnp.ndarray,
                             labels: jnp.ndarray,
                             **kwargs) -> jnp.ndarray:
  """Returns a binary array indicating where predictions match the labels."""
  del kwargs  # Unused.
  preds = (logits > 0)
  return (preds == labels).astype(jnp.float32)


@flax.struct.dataclass
class MeanAveragePrecision(
    metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
  """Computes the mean average precision (mAP) over different tasks."""

  def compute(self):
    # Matches the official OGB evaluation scheme for mean average precision.
    labels = self.values['labels']
    logits = self.values['logits']
    mask = self.values['mask']

    mask = mask.astype(np.bool)

    probs = jax.nn.sigmoid(logits)
    num_tasks = labels.shape[1]
    average_precisions = np.full(num_tasks, np.nan)

    # Note that this code is slow (~1 minute).
    for task in range(num_tasks):
      # AP is only defined when there is at least one negative data
      # and at least one positive data.
      if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0:
        is_labeled = mask[:, task]
        average_precisions[task] = average_precision_score(
            labels[is_labeled, task], probs[is_labeled, task])

    # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning.
    if np.isnan(average_precisions).all():
      return np.nan
    return np.nanmean(average_precisions)


@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  accuracy: metrics.Average.from_fun(predictions_match_labels)
  loss: metrics.Average.from_output('loss')
  mean_average_precision: MeanAveragePrecision
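As a rough, hedged example of how the EvalMetrics collection above is typically exercised with clu.metrics, assuming the definitions above are in scope. The synthetic batch and per-element loss below are placeholders; in the workload the values would come from model outputs and updates would be merged across evaluation batches:

# Hedged usage sketch with synthetic data; shapes and values are placeholders.
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
logits = jnp.asarray(rng.normal(size=(8, 128)))                              # (batch, num_tasks)
labels = jnp.asarray(rng.integers(0, 2, size=(8, 128)), dtype=jnp.float32)
mask = jnp.ones((8, 128))
# Per-element binary cross-entropy as a placeholder loss with a matching shape.
probs = jax.nn.sigmoid(logits)
loss = -(labels * jnp.log(probs) + (1.0 - labels) * jnp.log(1.0 - probs))

update = EvalMetrics.single_from_model_output(
    logits=logits, labels=labels, mask=mask, loss=loss)
# Across an evaluation run one would keep merging: total = total.merge(update).
results = update.compute()  # dict with 'accuracy', 'loss', 'mean_average_precision'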
@@ -0,0 +1,69 @@
# Forked from the init2winit implementation here
# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
from typing import Tuple

from flax import linen as nn
import jax.numpy as jnp
import jraph


def _make_embed(latent_dim):

  def make_fn(inputs):
    return nn.Dense(features=latent_dim)(inputs)

  return make_fn


def _make_mlp(hidden_dims, dropout):
  """Creates a MLP with specified dimensions."""

  @jraph.concatenated_args
  def make_fn(inputs):
    x = inputs
    for dim in hidden_dims:
      x = nn.Dense(features=dim)(x)
      x = nn.LayerNorm()(x)
      x = nn.relu(x)
      x = dropout(x)
    return x

  return make_fn


class GNN(nn.Module):
  """Defines a graph network.

  The model assumes the input data is a jraph.GraphsTuple without global
  variables. The final prediction will be encoded in the globals.
  """
  num_outputs: int = 128
  latent_dim: int = 256
  hidden_dims: Tuple[int] = (256,)
  dropout_rate: float = 0.1
  num_message_passing_steps: int = 5

  @nn.compact
  def __call__(self, graph, train):
    dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

    graph = graph._replace(
        globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))

    embedder = jraph.GraphMapFeatures(
        embed_node_fn=_make_embed(self.latent_dim),
        embed_edge_fn=_make_embed(self.latent_dim))
    graph = embedder(graph)

    for _ in range(self.num_message_passing_steps):
      net = jraph.GraphNetwork(
          update_edge_fn=_make_mlp(self.hidden_dims, dropout=dropout),
          update_node_fn=_make_mlp(self.hidden_dims, dropout=dropout),
          update_global_fn=_make_mlp(self.hidden_dims, dropout=dropout))

      graph = net(graph)

    # Map globals to represent the final result
    decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
    graph = decoder(graph)

    return graph.globals
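And a small, hedged sketch of initializing and applying the GNN above on a toy jraph.GraphsTuple, assuming the class definition above is in scope. The node/edge counts and feature widths are placeholders, not values prescribed by this commit:

# Illustrative toy graph and forward pass; all sizes are placeholders.
import jax
import jax.numpy as jnp
import jraph

toy_graph = jraph.GraphsTuple(
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    nodes=jnp.ones((3, 9)),     # placeholder node feature width
    edges=jnp.ones((2, 3)),     # placeholder edge feature width
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    globals=None)               # the model overwrites globals internally

model = GNN(num_outputs=128)
# train=False makes dropout deterministic, so no dropout RNG is needed here.
variables = model.init(jax.random.PRNGKey(0), toy_graph, train=False)
out = model.apply(variables, toy_graph, train=False)  # globals of shape (1, 128)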