Commit e5b5621

Merge branch 'main' into znado-lint

znado committed Mar 9, 2022
2 parents 76825bf + 4bc8247
Showing 21 changed files with 677 additions and 9 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -51,13 +51,13 @@
**JAX (GPU)**

```bash
-pip3 install -e .[jax-gpu] -f 'https://storage.googleapis.com/jax-releases/jax_releases.html'
+pip3 install -e .[jax_gpu] -f 'https://storage.googleapis.com/jax-releases/jax_releases.html'
```

**JAX (CPU)**

```bash
-pip3 install -e .[jax-cpu]
+pip3 install -e .[jax_cpu]
```

**PyTorch**
6 changes: 3 additions & 3 deletions algorithmic_efficiency/spec.py
@@ -2,7 +2,7 @@

import abc
import enum
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union


class LossType(enum.Enum):
@@ -203,8 +203,8 @@ def output_activation_fn(self, logits_batch: Tensor,
  def loss_fn(
      self,
      label_batch: Tensor,  # Dense (not one-hot) labels.
-      logits_batch: Tensor
-  ) -> Tensor:  # differentiable
+      logits_batch: Tensor,
+      mask_batch: Optional[Tensor]) -> Tensor:  # differentiable
    """return oned_array_of_losses_per_example"""

  @abc.abstractmethod
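The signature change above threads an optional `mask_batch` through `loss_fn`. As a point of reference, here is a minimal sketch of how a workload might implement the new signature; the sigmoid cross-entropy and the masking convention are illustrative assumptions, not code from this commit.

```python
from typing import Optional

import jax.numpy as jnp


def loss_fn(label_batch: jnp.ndarray,
            logits_batch: jnp.ndarray,
            mask_batch: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Returns a 1-D array of per-example losses (illustrative only)."""
  # Numerically stable sigmoid binary cross-entropy per label.
  per_label = (jnp.maximum(logits_batch, 0) - logits_batch * label_batch +
               jnp.log1p(jnp.exp(-jnp.abs(logits_batch))))
  if mask_batch is not None:
    # Zero out padded examples and unlabeled targets before reducing.
    per_label = per_label * mask_batch
  return per_label.sum(axis=-1)
```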
Empty file.
3 changes: 3 additions & 0 deletions algorithmic_efficiency/workloads/ogbg/ogbg_jax/README.md
@@ -0,0 +1,3 @@
## ogbg-molpcba classification

Based on the [Flax ogbg-molpcba example](https://github.com/google/flax/tree/main/examples/ogbg_molpcba).
Empty file.
161 changes: 161 additions & 0 deletions algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py
@@ -0,0 +1,161 @@
# Forked from Flax example which can be found here:
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/input_pipeline.py
# and from the init2winit fork here
# https://github.com/google/init2winit/blob/master/init2winit/dataset_lib/ogbg_molpcba.py
"""Exposes the ogbg-molpcba dataset in a convenient format."""

import jax
import jraph
import numpy as np
import tensorflow_datasets as tfds

AVG_NODES_PER_GRAPH = 26
AVG_EDGES_PER_GRAPH = 56


def _load_dataset(split, should_shuffle, data_rng, data_dir):
  """Loads a dataset split from TFDS."""
  if should_shuffle:
    file_data_rng, dataset_data_rng = jax.random.split(data_rng)
    file_data_rng = file_data_rng[0]
    dataset_data_rng = dataset_data_rng[0]
  else:
    file_data_rng = None
    dataset_data_rng = None

  read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng)
  dataset = tfds.load(
      'ogbg_molpcba',
      split=split,
      shuffle_files=should_shuffle,
      read_config=read_config,
      data_dir=data_dir)

  if should_shuffle:
    dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15)
    dataset = dataset.repeat()

  return dataset


def _to_jraph(example):
  """Converts an example graph to jraph.GraphsTuple."""
  example = jax.tree_map(lambda x: x._numpy(), example)  # pylint: disable=protected-access
  edge_feat = example['edge_feat']
  node_feat = example['node_feat']
  edge_index = example['edge_index']
  labels = example['labels']
  num_nodes = example['num_nodes']

  senders = edge_index[:, 0]
  receivers = edge_index[:, 1]

  return jraph.GraphsTuple(
      n_node=num_nodes,
      n_edge=np.array([len(edge_index) * 2]),
      nodes=node_feat,
      edges=np.concatenate([edge_feat, edge_feat]),
      # Make the edges bidirectional
      senders=np.concatenate([senders, receivers]),
      receivers=np.concatenate([receivers, senders]),
      # Keep the labels with the graph for batching. They will be removed
      # in the processed batch.
      globals=np.expand_dims(labels, axis=0))


def _get_weights_by_nan_and_padding(labels, padding_mask):
  """Handles NaNs and padding in labels.

  Sets all the weights from examples coming from padding to 0. Changes all NaNs
  in labels to 0s and sets the corresponding per-label weight to 0.

  Args:
    labels: Labels including labels from padded examples
    padding_mask: Binary array of which examples are padding

  Returns:
    tuple of (processed labels, corresponding weights)
  """
  nan_mask = np.isnan(labels)
  replaced_labels = np.copy(labels)
  np.place(replaced_labels, nan_mask, 0)

  weights = 1.0 - nan_mask
  # Weights for all labels of a padded element will be 0
  weights = weights * padding_mask[:, None]
  return replaced_labels, weights


def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
  """Turns a per-example iterator into a batched iterator.

  Constructs the batch from num_shards smaller batches, so that we can easily
  shard the batch to multiple devices during training. We use dynamic batching,
  so we specify some max number of graphs/nodes/edges, add as many graphs as we
  can, and then pad to the max values.

  Args:
    dataset_iter: The TFDS dataset iterator.
    global_batch_size: How many average-sized graphs go into the batch.
    num_shards: How many devices we should be able to shard the batch into.

  Yields:
    Batch in the init2winit format. Each field is a list of num_shards separate
    smaller batches.
  """
  if not num_shards:
    num_shards = jax.device_count()

  # We will construct num_shards smaller batches and then put them together.
  per_device_batch_size = global_batch_size / num_shards

  max_n_nodes = AVG_NODES_PER_GRAPH * per_device_batch_size
  max_n_edges = AVG_EDGES_PER_GRAPH * per_device_batch_size
  max_n_graphs = per_device_batch_size

  jraph_iter = map(_to_jraph, dataset_iter)
  batched_iter = jraph.dynamically_batch(jraph_iter,
                                         max_n_nodes + 1,
                                         max_n_edges,
                                         max_n_graphs + 1)

  count = 0
  graphs_shards = []
  labels_shards = []
  weights_shards = []

  for batched_graph in batched_iter:
    count += 1

    # Separate the labels from the graph
    labels = batched_graph.globals
    graph = batched_graph._replace(globals={})

    replaced_labels, weights = _get_weights_by_nan_and_padding(
        labels, jraph.get_graph_padding_mask(graph))

    graphs_shards.append(graph)
    labels_shards.append(replaced_labels)
    weights_shards.append(weights)

    if count == num_shards:

      def f(x):
        return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])

      graphs_shards = f(graphs_shards)
      labels_shards = f(labels_shards)
      weights_shards = f(weights_shards)
      yield (graphs_shards, labels_shards, weights_shards)

      count = 0
      graphs_shards = []
      labels_shards = []
      weights_shards = []


def get_dataset_iter(split, data_rng, data_dir, global_batch_size):
  ds = _load_dataset(
      split,
      should_shuffle=(split == 'train'),
      data_rng=data_rng,
      data_dir=data_dir)
  return _get_batch_iterator(iter(ds), global_batch_size)
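For orientation, a hedged sketch of how the sharded iterator above might be consumed; the `data_dir` path and batch size are placeholders, and it assumes the `ogbg_molpcba` TFDS dataset is already available on disk.

```python
import jax

data_rng = jax.random.PRNGKey(0)
train_iter = get_dataset_iter(
    'train', data_rng, data_dir='/tmp/tensorflow_datasets', global_batch_size=256)

# Each yielded field is stacked with a leading axis of size jax.device_count(),
# so the batch can be fed directly to a pmapped train step.
graphs, labels, weights = next(train_iter)
print(labels.shape)   # roughly (num_devices, graphs_per_device_batch, num_tasks)
print(weights.shape)  # same as labels; 0 where a label was NaN or padding
```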
58 changes: 58 additions & 0 deletions algorithmic_efficiency/workloads/ogbg/ogbg_jax/metrics.py
@@ -0,0 +1,58 @@
# Forked from Flax example which can be found here:
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py

from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.metrics import average_precision_score


def predictions_match_labels(*,
                             logits: jnp.ndarray,
                             labels: jnp.ndarray,
                             **kwargs) -> jnp.ndarray:
  """Returns a binary array indicating where predictions match the labels."""
  del kwargs  # Unused.
  preds = (logits > 0)
  return (preds == labels).astype(jnp.float32)


@flax.struct.dataclass
class MeanAveragePrecision(
    metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
  """Computes the mean average precision (mAP) over different tasks."""

  def compute(self):
    # Matches the official OGB evaluation scheme for mean average precision.
    labels = self.values['labels']
    logits = self.values['logits']
    mask = self.values['mask']

    mask = mask.astype(np.bool)

    probs = jax.nn.sigmoid(logits)
    num_tasks = labels.shape[1]
    average_precisions = np.full(num_tasks, np.nan)

    # Note that this code is slow (~1 minute).
    for task in range(num_tasks):
      # AP is only defined when there is at least one negative data
      # and at least one positive data.
      if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0:
        is_labeled = mask[:, task]
        average_precisions[task] = average_precision_score(
            labels[is_labeled, task], probs[is_labeled, task])

    # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning.
    if np.isnan(average_precisions).all():
      return np.nan
    return np.nanmean(average_precisions)


@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  accuracy: metrics.Average.from_fun(predictions_match_labels)
  loss: metrics.Average.from_output('loss')
  mean_average_precision: MeanAveragePrecision
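A rough sketch of how `EvalMetrics` might be accumulated over evaluation batches with the `clu` Collection API; the `eval_batches` iterator and the per-batch `logits`/`labels`/`mask`/`loss` values are placeholders, not part of this commit.

```python
eval_metrics = None
for logits, labels, mask, loss in eval_batches:  # hypothetical per-batch outputs
  update = EvalMetrics.single_from_model_output(
      logits=logits, labels=labels, mask=mask, loss=loss)
  eval_metrics = update if eval_metrics is None else eval_metrics.merge(update)

# compute() returns a dict keyed by 'accuracy', 'loss' and 'mean_average_precision'.
summary = eval_metrics.compute()
```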
69 changes: 69 additions & 0 deletions algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py
@@ -0,0 +1,69 @@
# Forked from the init2winit implementation here
# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
from typing import Tuple

from flax import linen as nn
import jax.numpy as jnp
import jraph


def _make_embed(latent_dim):

  def make_fn(inputs):
    return nn.Dense(features=latent_dim)(inputs)

  return make_fn


def _make_mlp(hidden_dims, dropout):
  """Creates a MLP with specified dimensions."""

  @jraph.concatenated_args
  def make_fn(inputs):
    x = inputs
    for dim in hidden_dims:
      x = nn.Dense(features=dim)(x)
      x = nn.LayerNorm()(x)
      x = nn.relu(x)
      x = dropout(x)
    return x

  return make_fn


class GNN(nn.Module):
  """Defines a graph network.

  The model assumes the input data is a jraph.GraphsTuple without global
  variables. The final prediction will be encoded in the globals.
  """
  num_outputs: int = 128
  latent_dim: int = 256
  hidden_dims: Tuple[int] = (256,)
  dropout_rate: float = 0.1
  num_message_passing_steps: int = 5

  @nn.compact
  def __call__(self, graph, train):
    dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

    graph = graph._replace(
        globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))

    embedder = jraph.GraphMapFeatures(
        embed_node_fn=_make_embed(self.latent_dim),
        embed_edge_fn=_make_embed(self.latent_dim))
    graph = embedder(graph)

    for _ in range(self.num_message_passing_steps):
      net = jraph.GraphNetwork(
          update_edge_fn=_make_mlp(self.hidden_dims, dropout=dropout),
          update_node_fn=_make_mlp(self.hidden_dims, dropout=dropout),
          update_global_fn=_make_mlp(self.hidden_dims, dropout=dropout))

      graph = net(graph)

    # Map globals to represent the final result
    decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
    graph = decoder(graph)

    return graph.globals
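To show the expected input/output contract, here is an illustrative initialization and forward pass on a made-up graph; the 9 node features and 3 edge features mirror ogbg-molpcba, but the toy graph and shapes are assumptions, not part of this commit.

```python
import jax
import jax.numpy as jnp
import jraph

toy_graph = jraph.GraphsTuple(
    n_node=jnp.array([3]),
    n_edge=jnp.array([4]),
    nodes=jnp.ones((3, 9)),       # 9 node features, as in ogbg-molpcba
    edges=jnp.ones((4, 3)),       # 3 edge features
    senders=jnp.array([0, 1, 1, 2]),
    receivers=jnp.array([1, 0, 2, 1]),
    globals=None)                 # the model writes its own globals

model = GNN(num_outputs=128)
variables = model.init(jax.random.PRNGKey(0), toy_graph, train=False)
preds = model.apply(variables, toy_graph, train=False)  # shape (1, 128)
```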