From 462ca51a2d3a5084b2be377969dc5473a92c8837 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 28 Oct 2021 14:12:30 +0200 Subject: [PATCH 01/46] Add OGB workload to workload registry --- submission_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 4fde34736..5e4d2746f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -41,6 +41,10 @@ 'workload_path': 'workloads/imagenet/imagenet_jax/workload.py', 'workload_class_name': 'ImagenetWorkload' }, + 'ogb_jax': { + 'workload_path': 'workloads/ogb/ogb_jax/workload.py', + 'workload_class_name': 'OGBWorkload' + }, 'wmt_jax': { 'workload_path': 'workloads/wmt/wmt_jax/workload.py', 'workload_class_name': 'WMTWorkload' @@ -325,4 +329,3 @@ def main(_): if __name__ == '__main__': app.run(main) - From 6369831566907b4a6c36b58f8c55009129389cf4 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 28 Oct 2021 14:14:44 +0200 Subject: [PATCH 02/46] Add initial OGB workload draft --- workloads/ogb/__init__.py | 0 workloads/ogb/ogb_jax/README.md | 3 + workloads/ogb/ogb_jax/__init__.py | 0 workloads/ogb/ogb_jax/input_pipeline.py | 224 ++++++++++++++++ workloads/ogb/ogb_jax/models.py | 178 +++++++++++++ workloads/ogb/ogb_jax/submission.py | 92 +++++++ .../ogb/ogb_jax/tuning_search_space.json | 3 + workloads/ogb/ogb_jax/workload.py | 239 ++++++++++++++++++ workloads/ogb/workload.py | 75 ++++++ 9 files changed, 814 insertions(+) create mode 100644 workloads/ogb/__init__.py create mode 100644 workloads/ogb/ogb_jax/README.md create mode 100644 workloads/ogb/ogb_jax/__init__.py create mode 100644 workloads/ogb/ogb_jax/input_pipeline.py create mode 100644 workloads/ogb/ogb_jax/models.py create mode 100644 workloads/ogb/ogb_jax/submission.py create mode 100644 workloads/ogb/ogb_jax/tuning_search_space.json create mode 100644 workloads/ogb/ogb_jax/workload.py create mode 100644 workloads/ogb/workload.py diff --git a/workloads/ogb/__init__.py b/workloads/ogb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/workloads/ogb/ogb_jax/README.md b/workloads/ogb/ogb_jax/README.md new file mode 100644 index 000000000..780f29f34 --- /dev/null +++ b/workloads/ogb/ogb_jax/README.md @@ -0,0 +1,3 @@ +## ogbg-molpcba classification + +Based on the [Flax ogbg-molpcba example](https://github.com/google/flax/tree/main/examples/ogbg_molpcba). diff --git a/workloads/ogb/ogb_jax/__init__.py b/workloads/ogb/ogb_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/workloads/ogb/ogb_jax/input_pipeline.py b/workloads/ogb/ogb_jax/input_pipeline.py new file mode 100644 index 000000000..8c44dda2b --- /dev/null +++ b/workloads/ogb/ogb_jax/input_pipeline.py @@ -0,0 +1,224 @@ +# Forked from Flax example which can be found here: +# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/input_pipeline.py + +"""Exposes the ogbg-molpcba dataset in a convenient format.""" + +import functools +from typing import Dict, NamedTuple +import jraph +import numpy as np +import tensorflow as tf +# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. 
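+# (This call only changes what TensorFlow can see; JAX still uses the GPUs.
+# It has to run before TensorFlow initializes its GPU context, which is also
+# why tensorflow_datasets is only imported after this line.)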
+tf.config.experimental.set_visible_devices([], 'GPU') +import tensorflow_datasets as tfds + + +class GraphsTupleSize(NamedTuple): + """Helper class to represent padding and graph sizes.""" + n_node: int + n_edge: int + n_graph: int + + +def get_raw_datasets() -> Dict[str, tf.data.Dataset]: + """Returns datasets as tf.data.Dataset, organized by split.""" + ds_builder = tfds.builder('ogbg_molpcba') + ds_builder.download_and_prepare() + ds_splits = ['train', 'validation', 'test'] + datasets = { + split: ds_builder.as_dataset(split=split) for split in ds_splits + } + return datasets + + +def get_datasets(batch_size: int, + add_virtual_node: bool = True, + add_undirected_edges: bool = True, + add_self_loops: bool = True) -> Dict[str, tf.data.Dataset]: + """Returns datasets of batched GraphsTuples, organized by split.""" + if batch_size <= 1: + raise ValueError('Batch size must be > 1 to account for padding graphs.') + + # Obtain the original datasets. + datasets = get_raw_datasets() + + # Construct the GraphsTuple converter function. + convert_to_graphs_tuple_fn = functools.partial( + convert_to_graphs_tuple, + add_virtual_node=add_self_loops, + add_undirected_edges=add_undirected_edges, + add_self_loops=add_virtual_node, + ) + + # Process each split separately. + for split_name in datasets: + + # Convert to GraphsTuple. + datasets[split_name] = datasets[split_name].map( + convert_to_graphs_tuple_fn, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=True) + + # Compute the padding budget for the requested batch size. + budget = estimate_padding_budget_for_batch_size(datasets['train'], batch_size, + num_estimation_graphs=100) + + # Pad an example graph to see what the output shapes will be. + # We will use this shape information when creating the tf.data.Dataset. + example_graph = next(datasets['train'].as_numpy_iterator()) + example_padded_graph = jraph.pad_with_graphs(example_graph, *budget) + padded_graphs_spec = specs_from_graphs_tuple(example_padded_graph) + + # Process each split separately. + for split_name, dataset_split in datasets.items(): + + # Repeat and shuffle the training split. + if split_name == 'train': + dataset_split = dataset_split.shuffle(100, reshuffle_each_iteration=True) + dataset_split = dataset_split.repeat() + + # Batch and pad each split. + batching_fn = functools.partial( + jraph.dynamically_batch, + graphs_tuple_iterator=iter(dataset_split), + n_node=budget.n_node, + n_edge=budget.n_edge, + n_graph=budget.n_graph) + dataset_split = tf.data.Dataset.from_generator( + batching_fn, + output_signature=padded_graphs_spec) + + # We cache the validation and test sets, since these are small. + if split_name in ['validation', 'test']: + dataset_split = dataset_split.cache() + + datasets[split_name] = dataset_split + return datasets + + +def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], + add_virtual_node: bool, + add_undirected_edges: bool, + add_self_loops: bool) -> jraph.GraphsTuple: + """Converts a dictionary of tf.Tensors to a GraphsTuple.""" + num_nodes = tf.squeeze(graph['num_nodes']) + num_edges = tf.squeeze(graph['num_edges']) + nodes = graph['node_feat'] + edges = graph['edge_feat'] + edge_feature_dim = edges.shape[-1] + labels = graph['labels'] + senders = graph['edge_index'][:, 0] + receivers = graph['edge_index'][:, 1] + + # Add a virtual node connected to all other nodes. + # The feature vectors for the virtual node + # and the new edges are set to all zeros. 
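+  # Illustrative example: a molecule with 5 atoms and 4 bonds becomes a graph
+  # with 6 nodes and 9 edges after this step, since one zero-featured edge is
+  # added from every original node to the virtual node.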
+ if add_virtual_node: + nodes = tf.concat( + [nodes, tf.zeros_like(nodes[0, None])], axis=0) + senders = tf.concat( + [senders, tf.range(num_nodes)], axis=0) + receivers = tf.concat( + [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) + edges = tf.concat( + [edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) + num_edges += num_nodes + num_nodes += 1 + + # Make edges undirected, by adding edges with senders and receivers flipped. + # The feature vector for the flipped edge is the same as the original edge. + if add_undirected_edges: + new_senders = tf.concat([senders, receivers], axis=0) + new_receivers = tf.concat([receivers, senders], axis=0) + edges = tf.concat([edges, edges], axis=0) + senders, receivers = new_senders, new_receivers + num_edges *= 2 + + # Add self-loops for each node. + # The feature vectors for the self-loops are set to all zeros. + if add_self_loops: + senders = tf.concat([senders, tf.range(num_nodes)], axis=0) + receivers = tf.concat([receivers, tf.range(num_nodes)], axis=0) + edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) + num_edges += num_nodes + + return jraph.GraphsTuple( + n_node=tf.expand_dims(num_nodes, 0), + n_edge=tf.expand_dims(num_edges, 0), + nodes=nodes, + edges=edges, + senders=senders, + receivers=receivers, + globals=tf.expand_dims(labels, axis=0), + ) + + +def estimate_padding_budget_for_batch_size( + dataset: tf.data.Dataset, + batch_size: int, + num_estimation_graphs: int) -> GraphsTupleSize: + """Estimates the padding budget for a dataset of unbatched GraphsTuples. + + Args: + dataset: A dataset of unbatched GraphsTuples. + batch_size: The intended batch size. Note that no batching is performed by + this function. + num_estimation_graphs: How many graphs to take from the dataset to estimate + the distribution of number of nodes and edges per graph. + + Returns: + padding_budget: The padding budget for batching and padding the graphs + in this dataset to the given batch size. 
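+      The budget is only an estimate: it averages the node and edge counts of
+      the first `num_estimation_graphs` graphs and rounds the per-batch totals
+      up to the next multiple of 64. As a ballpark, with batch_size=256 and
+      graphs averaging roughly 26 nodes and 28 edges (typical figures for
+      ogbg-molpcba), this yields about n_node=6720, n_edge=7232, n_graph=256.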
+ """ + + def next_multiple_of_64(val: float): + """Returns the next multiple of 64 after val.""" + return 64 * (1 + int(val // 64)) + + if batch_size <= 1: + raise ValueError('Batch size must be > 1 to account for padding graphs.') + + total_num_nodes = 0 + total_num_edges = 0 + for graph in dataset.take(num_estimation_graphs).as_numpy_iterator(): + graph_size = get_graphs_tuple_size(graph) + if graph_size.n_graph != 1: + raise ValueError('Dataset contains batched GraphTuples.') + + total_num_nodes += graph_size.n_node + total_num_edges += graph_size.n_edge + + num_nodes_per_graph_estimate = total_num_nodes / num_estimation_graphs + num_edges_per_graph_estimate = total_num_edges / num_estimation_graphs + + padding_budget = GraphsTupleSize( + n_node=next_multiple_of_64(num_nodes_per_graph_estimate * batch_size), + n_edge=next_multiple_of_64(num_edges_per_graph_estimate * batch_size), + n_graph=batch_size) + return padding_budget + + +def specs_from_graphs_tuple(graph: jraph.GraphsTuple): + """Returns a tf.TensorSpec corresponding to this graph.""" + + def get_tensor_spec(array: np.ndarray): + shape = list(array.shape) + dtype = array.dtype + return tf.TensorSpec(shape=shape, dtype=dtype) + + specs = {} + for field in [ + 'nodes', 'edges', 'senders', 'receivers', 'globals', 'n_node', 'n_edge' + ]: + field_sample = getattr(graph, field) + specs[field] = get_tensor_spec(field_sample) + return jraph.GraphsTuple(**specs) + + +def get_graphs_tuple_size(graph: jraph.GraphsTuple): + """Returns the number of nodes, edges and graphs in a GraphsTuple.""" + return GraphsTupleSize( + n_node=np.sum(graph.n_node), + n_edge=np.sum(graph.n_edge), + n_graph=np.shape(graph.n_node)[0]) diff --git a/workloads/ogb/ogb_jax/models.py b/workloads/ogb/ogb_jax/models.py new file mode 100644 index 000000000..ce5d19d08 --- /dev/null +++ b/workloads/ogb/ogb_jax/models.py @@ -0,0 +1,178 @@ +# Forked from Flax example which can be found here: +# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/models.py + +"""Definition of the GNN model.""" + +from typing import Callable, Sequence + +from flax import linen as nn +import jax.numpy as jnp +import jraph + + +def add_graphs_tuples(graphs: jraph.GraphsTuple, + other_graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: + """Adds the nodes, edges and global features from other_graphs to graphs.""" + return graphs._replace( + nodes=graphs.nodes + other_graphs.nodes, + edges=graphs.edges + other_graphs.edges, + globals=graphs.globals + other_graphs.globals) + + +class MLP(nn.Module): + """A multi-layer perceptron.""" + + feature_sizes: Sequence[int] + dropout_rate: float = 0 + deterministic: bool = True + activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, inputs): + x = inputs + for size in self.feature_sizes: + x = nn.Dense(features=size)(x) + x = self.activation(x) + x = nn.Dropout( + rate=self.dropout_rate, deterministic=self.deterministic)(x) + return x + + +class GraphNet(nn.Module): + """A complete Graph Network model defined with Jraph.""" + + latent_size: int + num_mlp_layers: int + message_passing_steps: int + output_globals_size: int + dropout_rate: float = 0 + skip_connections: bool = True + use_edge_model: bool = True + layer_norm: bool = True + deterministic: bool = True + + @nn.compact + def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: + # We will first linearly project the original features as 'embeddings'. 
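+    # GraphMapFeatures applies each Dense layer independently to the node,
+    # edge and global features, so all three carry `latent_size` features
+    # before message passing starts.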
+ embedder = jraph.GraphMapFeatures( + embed_node_fn=nn.Dense(self.latent_size), + embed_edge_fn=nn.Dense(self.latent_size), + embed_global_fn=nn.Dense(self.latent_size)) + processed_graphs = embedder(graphs) + + # Now, we will apply a Graph Network once for each message-passing round. + mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers + for _ in range(self.message_passing_steps): + if self.use_edge_model: + update_edge_fn = jraph.concatenated_args( + MLP(mlp_feature_sizes, + dropout_rate=self.dropout_rate, + deterministic=self.deterministic)) + else: + update_edge_fn = None + + update_node_fn = jraph.concatenated_args( + MLP(mlp_feature_sizes, + dropout_rate=self.dropout_rate, + deterministic=self.deterministic)) + update_global_fn = jraph.concatenated_args( + MLP(mlp_feature_sizes, + dropout_rate=self.dropout_rate, + deterministic=self.deterministic)) + + graph_net = jraph.GraphNetwork( + update_node_fn=update_node_fn, + update_edge_fn=update_edge_fn, + update_global_fn=update_global_fn) + + if self.skip_connections: + processed_graphs = add_graphs_tuples( + graph_net(processed_graphs), processed_graphs) + else: + processed_graphs = graph_net(processed_graphs) + + if self.layer_norm: + processed_graphs = processed_graphs._replace( + nodes=nn.LayerNorm()(processed_graphs.nodes), + edges=nn.LayerNorm()(processed_graphs.edges), + globals=nn.LayerNorm()(processed_graphs.globals), + ) + + # Since our graph-level predictions will be at globals, we will + # decode to get the required output logits. + decoder = jraph.GraphMapFeatures( + embed_global_fn=nn.Dense(self.output_globals_size)) + processed_graphs = decoder(processed_graphs) + + return processed_graphs + + +class GraphConvNet(nn.Module): + """A Graph Convolution Network + Pooling model defined with Jraph.""" + + latent_size: int + num_mlp_layers: int + message_passing_steps: int + output_globals_size: int + dropout_rate: float = 0 + skip_connections: bool = True + layer_norm: bool = True + deterministic: bool = True + pooling_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], + jnp.ndarray] = jraph.segment_mean + + def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: + """Pooling operation, taken from Jraph.""" + + # Equivalent to jnp.sum(n_node), but JIT-able. + sum_n_node = graphs.nodes.shape[0] + # To aggregate nodes from each graph to global features, + # we first construct tensors that map the node to the corresponding graph. + # Example: if you have `n_node=[1,2]`, we construct the tensor [0, 1, 1]. + n_graph = graphs.n_node.shape[0] + node_graph_indices = jnp.repeat( + jnp.arange(n_graph), + graphs.n_node, + axis=0, + total_repeat_length=sum_n_node) + # We use the aggregation function to pool the nodes per graph. + pooled = self.pooling_fn(graphs.nodes, node_graph_indices, n_graph) + return graphs._replace(globals=pooled) + + @nn.compact + def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: + # We will first linearly project the original node features as 'embeddings'. + embedder = jraph.GraphMapFeatures( + embed_node_fn=nn.Dense(self.latent_size)) + processed_graphs = embedder(graphs) + + # Now, we will apply the GCN once for each message-passing round. 
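+    # Each round first transforms the node features with a small MLP and then
+    # propagates them over each node's neighbourhood via
+    # jraph.GraphConvolution (self-edges are added, so every node also keeps
+    # a contribution from itself).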
+ for _ in range(self.message_passing_steps): + mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers + update_node_fn = jraph.concatenated_args( + MLP(mlp_feature_sizes, + dropout_rate=self.dropout_rate, + deterministic=self.deterministic)) + graph_conv = jraph.GraphConvolution( + update_node_fn=update_node_fn, add_self_edges=True) + + if self.skip_connections: + processed_graphs = add_graphs_tuples( + graph_conv(processed_graphs), processed_graphs) + else: + processed_graphs = graph_conv(processed_graphs) + + if self.layer_norm: + processed_graphs = processed_graphs._replace( + nodes=nn.LayerNorm()(processed_graphs.nodes), + ) + + # We apply the pooling operation to get a 'global' embedding. + processed_graphs = self.pool(processed_graphs) + + # Now, we decode this to get the required output logits. + decoder = jraph.GraphMapFeatures( + embed_global_fn=nn.Dense(self.output_globals_size)) + processed_graphs = decoder(processed_graphs) + + return processed_graphs diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py new file mode 100644 index 000000000..eb7ef1895 --- /dev/null +++ b/workloads/ogb/ogb_jax/submission.py @@ -0,0 +1,92 @@ +from typing import List, Tuple + +import jax +import jax.numpy as jnp +import optax + +import spec + + +def get_batch_size(workload_name): + del workload_name + return 256 + + +def optimizer(hyperparameters: spec.Hyperparamters + ) -> optax.GradientTransformation: + """Creates an optimizer.""" + opt_init_fn, opt_update_fn = optax.adam( + learning_rate=hyperparameters.learning_rate) + return opt_init_fn, opt_update_fn + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparamters, + rng: spec.RandomState) -> spec.OptimizerState: + params_zeros_like = jax.tree_map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) + opt_init_fn, opt_update_fn = optimizer( + hyperparameters, workload.num_train_examples) + init_optimizer_state = opt_init_fn(params_zeros_like) + return init_optimizer_state, opt_update_fn + + +@jax.jit +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparamters, + input_batch: spec.Tensor, + label_batch: spec.Tensor, + # This will define the output activation via `output_activation_fn`. 
+ loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del global_step + + def loss_fn(params): + logits_batch, new_model_state = self.model_fn( + params, + input_batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss = workload.loss_fn(label_batch, logits_batch) + return loss, new_model_state + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (_, new_model_state), grad = grad_fn(current_param_container) + optimizer_state, opt_update_fn = optimizer_state + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return (new_optimizer_state, opt_update_fn), updated_params, new_model_state + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Tuple[spec.Tensor, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparamters, + global_step: int, + rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a single training example and label. + Return a tuple of input label batches. + """ + graphs = next(input_queue) + labels = graphs.globals + return x, labels diff --git a/workloads/ogb/ogb_jax/tuning_search_space.json b/workloads/ogb/ogb_jax/tuning_search_space.json new file mode 100644 index 000000000..622722c6c --- /dev/null +++ b/workloads/ogb/ogb_jax/tuning_search_space.json @@ -0,0 +1,3 @@ +{ + "learning_rate": {"feasible_points": [1e-3]}, +} diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py new file mode 100644 index 000000000..8c29e7dfc --- /dev/null +++ b/workloads/ogb/ogb_jax/workload.py @@ -0,0 +1,239 @@ +"""OGB workload implemented in Jax.""" + +from typing import Tuple +import sklearn.metrics + +import jax +import jax.numpy as jnp +from flax import linen as nn + +import spec +import input_pipeline +import models +from workloads.ogb.workload import OGB + + +class OGBWorkload(OGB): + + def __init__(self): + self._eval_ds = None + self._param_shapes = None + self._init_graphs = None + self._mask = None + self._model = models.GraphNet( + latent_size=256, + num_mlp_layers=2, + message_passing_steps=5, + output_globals_size=128, + dropout_rate=0.1, + skip_connections=True, + layer_norm=True, + use_edge_model=True, + deterministic=True) + + def _normalize(self, image): + pass + + def _build_dataset(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size): + datasets = input_pipeline.get_datasets( + batch_size, + add_virtual_node=False, + add_undirected_edges=True, + add_self_loops=True) + if self._init_graphs is None: + self._init_graphs = next(datasets['train'].as_numpy_iterator()) + return datasets[split] + + def build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int): + return iter(self._build_dataset(data_rng, split, data_dir, batch_size)) + + @property + def param_shapes(self): + if self._param_shapes is None: + raise ValueError( + 'This should not happen, workload.init_model_fn() should be called ' + 'before workload.param_shapes!') + return 
self._param_shapes + + def model_params_types(self): + pass + + # Return whether or not a key in spec.ParameterContainer is the output layer + # parameters. + def is_output_params(self, param_key: spec.ParameterKey) -> bool: + pass + + def preprocess_for_train( + self, + selected_raw_input_batch: spec.Tensor, + selected_label_batch: spec.Tensor, + train_mean: spec.Tensor, + train_stddev: spec.Tensor, + rng: spec.RandomState) -> spec.Tensor: + del train_mean + del train_stddev + del rng + return selected_raw_input_batch, selected_label_batch + + def preprocess_for_eval( + self, + raw_input_batch: spec.Tensor, + raw_label_batch: spec.Tensor, + train_mean: spec.Tensor, + train_stddev: spec.Tensor) -> spec.Tensor: + del train_mean + del train_stddev + return raw_input_batch, raw_label_batch + + def _replace_globals(self, graphs: jraph.Graphssklearn) -> jraph.Graphssklearn: + """Replaces the globals attribute with a constant feature for each graph.""" + return graphs._replace(globals=jnp.ones([graphs.n_node.shape[0], 1])) + + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + if self._init_graphs is None: + raise ValueError( + 'This should not happen, workload.build_input_queue() should be ' + 'called before workload.init_model_fn()!' + ) + rng, init_rng = jax.random.split(rng) + init_graphs = self._replace_globals(self._init_graphs) + params = jax.jit(self._model.init)(init_rng, init_graphs) + self._param_shapes = jax.tree_map( + lambda x: spec.ShapeTuple(x.shape), + params) + return params, None + + # Keep this separate from the loss function in order to support optimizers + # that use the logits. + def output_activation_fn( + self, + logits_batch: spec.Tensor, + loss_type: spec.LossType) -> spec.Tensor: + pass + + def _get_valid_mask(labels: jnp.ndarray, + graphs: jraph.GraphsTuple) -> jnp.ndarray: + """Gets the binary mask indicating only valid labels and graphs.""" + # We have to ignore all NaN values - which indicate labels for which + # the current graphs have no label. + labels_mask = ~jnp.isnan(labels) + + # Since we have extra 'dummy' graphs in our batch due to padding, we want + # to mask out any loss associated with the dummy graphs. + # Since we padded with `pad_with_graphs` we can recover the mask by using + # get_graph_padding_mask. + graph_mask = jraph.get_graph_padding_mask(graphs) + + # Combine the mask over labels with the mask over graphs. + return labels_mask & graph_mask[:, None] + + def model_fn( + self, + params: spec.ParameterContainer, + input_batch: spec.Tensor, + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + """Get predicted logits from the network for input graphs.""" + # Extract labels. + labels = input_batch.globals + # Replace the global feature for graph classification. + graphs = self._replace_globals(input_batch) + + # Get predicted logits + variables = {'params': params, **model_state} + train = mode == spec.ForwardPassMode.TRAIN + pred_graphs = self._model.apply( + variables, + graphs, + train=train, + rngs=rng) + logits = pred_graphs.globals + + # Get the mask for valid labels and graphs. 
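+    # The mask is stored on the workload because loss_fn only receives labels
+    # and logits; it is True exactly for entries with a real (non-NaN) label
+    # in a non-padding graph, and the loss and eval metrics only count those.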
+ self._mask = self._get_valid_mask(labels, graphs) + + return logits, None + + def _binary_cross_entropy_with_mask( + self, + logits: jnp.ndarray, + labels: jnp.ndarray, + mask: jnp.ndarray) -> jnp.ndarray: + """Binary cross entropy loss for logits, with masked elements.""" + assert logits.shape == labels.shape == mask.shape + assert len(logits.shape) == 2 + + # To prevent propagation of NaNs during grad(). + # We mask over the loss for invalid targets later. + labels = jnp.where(mask, labels, -1) + + # Numerically stable implementation of BCE loss. + # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits(). + positive_logits = (logits >= 0) + relu_logits = jnp.where(positive_logits, logits, 0) + abs_logits = jnp.where(positive_logits, logits, -logits) + return relu_logits - (logits * labels) + ( + jnp.log(1 + jnp.exp(-abs_logits))) + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor) -> spec.Tensor: # differentiable + if self._mask is None: + raise ValueError( + 'This should not happen, workload.model_fn() should be ' + 'called before workload.loss_fn()!' + ) + loss = self._binary_cross_entropy_with_mask( + logits=logits_batch, labels=label_batch, mask=self._mask) + mean_loss = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask) + return mean_loss + + def _compute_mean_average_precision(self, labels, logits): + """Computes the mean average precision (mAP) over different tasks.""" + # Matches the official OGB evaluation scheme for mean average precision. + assert logits.shape == labels.shape == self._mask.shape + assert len(logits.shape) == 2 + + probs = jax.nn.sigmoid(logits) + num_tasks = labels.shape[1] + average_precisions = np.full(num_tasks, np.nan) + + for task in range(num_tasks): + # AP is only defined when there is at least one negative data + # and at least one positive data. + if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: + is_labeled = self._mask[:, task] + average_precisions[task] = sklearn.metrics.average_precision_score( + labels[is_labeled, task], probs[is_labeled, task]) + + # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. 
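+    # (np.nanmean emits a 'Mean of empty slice' RuntimeWarning on an all-NaN
+    # array, hence the explicit check.)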
+ if np.isnan(average_precisions).all(): + return np.nan + return np.nanmean(average_precisions) + + def _eval_metric(self, labels, logits): + """Return the accuracy, average precision, and loss as a dict.""" + preds = (logits > 0) + accuracy = np.nanmean((preds == labels).astype(jnp.float32)) + average_precision = self._compute_mean_average_precision(labels, logits) + loss = self.loss_fn(labels, logits) + metrics = { + 'accuracy': accuracy, + 'average_precision': average_precision, + 'loss': loss, + } + return metrics diff --git a/workloads/ogb/workload.py b/workloads/ogb/workload.py new file mode 100644 index 000000000..27396e821 --- /dev/null +++ b/workloads/ogb/workload.py @@ -0,0 +1,75 @@ +import random_utils as prng +import spec + + +class OGB(spec.Workload): + + def has_reached_goal(self, eval_result: float) -> bool: + return eval_result['average_precision'] > self.target_value + + @property + def target_value(self): + return 0.24 + + @property + def loss_type(self): + return spec.LossType.SIGMOID_CROSS_ENTROPY + + @property + def num_train_examples(self): + return 350343 + + @property + def num_eval_examples(self): + return 43793 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self): + raise NotImplementedError + + @property + def eval_period_time_sec(self): + raise NotImplementedError + + def eval_model( + self, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str): + """Run a full evaluation of the model.""" + data_rng, model_rng = prng.split(rng, 2) + eval_batch_size = 256 + num_batches = self.num_eval_examples // eval_batch_size + if self._eval_ds is None: + self._eval_ds = self.build_input_queue( + data_rng, 'test', data_dir, batch_size=eval_batch_size) + + total_metrics = { + 'accuracy': 0., + 'average_precision': 0., + 'loss': 0., + } + # Loop over graphs. 
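+    # Each element yielded here is already a padded batch of graphs (see
+    # input_pipeline), so this loops over batches rather than single graphs.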
+ for graphs in self._eval_ds.as_numpy_iterator(): + logits, _ = self.model_fn( + params, + graphs, + model_state, + spec.ForwardPassMode.EVAL, + model_rng, + update_batch_norm=False) + labels = graphs.globals + batch_metrics = self._eval_metric(labels, logits) + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } + return {k: float(v / num_batches) for k, v in total_metrics.items()} From 89dd5e2f8bfeb239fe44ea2a09fcacfeb54265be Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Thu, 28 Oct 2021 15:08:55 +0200 Subject: [PATCH 03/46] Fix missing indent --- workloads/ogb/ogb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index 8c29e7dfc..d8094d3ce 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -95,7 +95,7 @@ def preprocess_for_eval( return raw_input_batch, raw_label_batch def _replace_globals(self, graphs: jraph.Graphssklearn) -> jraph.Graphssklearn: - """Replaces the globals attribute with a constant feature for each graph.""" + """Replaces the globals attribute with a constant feature for each graph.""" return graphs._replace(globals=jnp.ones([graphs.n_node.shape[0], 1])) def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: From 00f2118f76799b825f7f81dabf6e8e69b9f39375 Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Thu, 28 Oct 2021 18:53:20 +0200 Subject: [PATCH 04/46] Fix imports; change model; fix typo --- workloads/ogb/ogb_jax/submission.py | 2 +- workloads/ogb/ogb_jax/workload.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index eb7ef1895..11b4c1ba7 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import Iterator, List, Tuple import jax import jax.numpy as jnp diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index d8094d3ce..147922d00 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -5,12 +5,13 @@ import jax import jax.numpy as jnp +import jraph from flax import linen as nn import spec -import input_pipeline -import models from workloads.ogb.workload import OGB +from workloads.ogb.ogb_jax import input_pipeline +from workloads.ogb.ogb_jax import models class OGBWorkload(OGB): @@ -20,7 +21,7 @@ def __init__(self): self._param_shapes = None self._init_graphs = None self._mask = None - self._model = models.GraphNet( + self._model = models.GraphConvNet( latent_size=256, num_mlp_layers=2, message_passing_steps=5, @@ -28,7 +29,6 @@ def __init__(self): dropout_rate=0.1, skip_connections=True, layer_norm=True, - use_edge_model=True, deterministic=True) def _normalize(self, image): @@ -94,7 +94,7 @@ def preprocess_for_eval( del train_stddev return raw_input_batch, raw_label_batch - def _replace_globals(self, graphs: jraph.Graphssklearn) -> jraph.Graphssklearn: + def _replace_globals(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: """Replaces the globals attribute with a constant feature for each graph.""" return graphs._replace(globals=jnp.ones([graphs.n_node.shape[0], 1])) From f78040ef90128d406e6744db410d8b829380c59a Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Tue, 2 Nov 2021 19:13:58 -0400 Subject: [PATCH 05/46] znado jax cleanup of submission runner and GNN --- submission_runner.py | 58 
++++++++--------- workloads/ogb/ogb_jax/submission.py | 65 ++++++++++++------- .../ogb/ogb_jax/tuning_search_space.json | 4 +- workloads/ogb/ogb_jax/workload.py | 11 ++-- 4 files changed, 75 insertions(+), 63 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 5e4d2746f..47886525a 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -94,9 +94,9 @@ def _convert_filepath_to_module(path: str): def _import_workload( - workload_path, - workload_registry_name, - workload_class_name): + workload_path: str, + workload_registry_name: str, + workload_class_name: str) -> spec.Workload: """Import and add the workload to the registry. This importlib loading is nice to have because it allows runners to avoid @@ -115,26 +115,21 @@ def _import_workload( # Remove the trailing '.py' and convert the filepath to a Python module. workload_path = _convert_filepath_to_module(workload_path) - try: - # Import the workload module. - workload_module = importlib.import_module(workload_path) - # Get everything defined in the workload module (including our class). - workload_module_members = inspect.getmembers(workload_module) - workload_class = None - for name, value in workload_module_members: - if name == workload_class_name: - workload_class = value - break - if workload_class is None: - raise ValueError( - f'Could not find member {workload_class_name} in {workload_path}. ' - 'Make sure the Workload class is spelled correctly and defined in ' - 'the top scope of the module.') - WORKLOADS[workload_registry_name] = workload_class() - except ModuleNotFoundError as err: - logging.warning( - f'Could not import workload module {workload_path}, ' - f'continuing:\n\n{err}\n') + # Import the workload module. + workload_module = importlib.import_module(workload_path) + # Get everything defined in the workload module (including our class). + workload_module_members = inspect.getmembers(workload_module) + workload_class = None + for name, value in workload_module_members: + if name == workload_class_name: + workload_class = value + break + if workload_class is None: + raise ValueError( + f'Could not find member {workload_class_name} in {workload_path}. ' + 'Make sure the Workload class is spelled correctly and defined in ' + 'the top scope of the module.') + return workload_class() # Example reference implementation showing how to use the above functions @@ -192,7 +187,7 @@ def train_once( optimizer_state, model_params, model_state = update_params( workload=workload, current_param_container=model_params, - current_params_types=workload.model_params_types, + current_params_types=workload.model_params_types(), model_state=model_state, hyperparameters=hyperparameters, input_batch=selected_train_input_batch, @@ -225,6 +220,7 @@ def train_once( def score_submission_on_workload( + workload: spec.Workload, workload_name: str, submission_path: str, data_dir: str, @@ -241,8 +237,6 @@ def score_submission_on_workload( get_batch_size = submission_module.get_batch_size batch_size = get_batch_size(workload_name) - workload = WORKLOADS[workload_name] - if tuning_ruleset == 'external': # If the submission runner is responsible for hyperparameter tuning, load in # the search space and generate a list of randomly selected hyperparameter @@ -310,14 +304,14 @@ def main(_): # it unavailable to JAX. 
tf.config.experimental.set_visible_devices([], 'GPU') - for workload_name, workload in WORKLOADS.items(): - _import_workload( - workload_path=workload['workload_path'], - workload_registry_name=workload_name, - workload_class_name=workload['workload_class_name'] - ) + workload_metadata = WORKLOADS[FLAGS.workload] + workload = _import_workload( + workload_path=workload_metadata['workload_path'], + workload_registry_name=FLAGS.workload, + workload_class_name=workload_metadata['workload_class_name']) score = score_submission_on_workload( + workload, FLAGS.workload, FLAGS.submission_path, FLAGS.data_dir, diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index 11b4c1ba7..bb226c3fc 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -1,5 +1,6 @@ from typing import Iterator, List, Tuple +import functools import jax import jax.numpy as jnp import optax @@ -12,8 +13,7 @@ def get_batch_size(workload_name): return 256 -def optimizer(hyperparameters: spec.Hyperparamters - ) -> optax.GradientTransformation: +def optimizer(hyperparameters: spec.Hyperparamters) -> optax.GradientTransformation: """Creates an optimizer.""" opt_init_fn, opt_update_fn = optax.adam( learning_rate=hyperparameters.learning_rate) @@ -28,13 +28,37 @@ def init_optimizer_state( rng: spec.RandomState) -> spec.OptimizerState: params_zeros_like = jax.tree_map( lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer( - hyperparameters, workload.num_train_examples) + opt_init_fn, opt_update_fn = optimizer(hyperparameters) init_optimizer_state = opt_init_fn(params_zeros_like) return init_optimizer_state, opt_update_fn -@jax.jit +# We need to jax.pmap here instead of inside update_params because the latter +# would recompile the function every step. +@functools.partial( + jax.jit, + static_argnums=(0, 1)) +def pmapped_train_step(workload, opt_update_fn, model_state, optimizer_state, + current_param_container, hyperparameters, input_batch, label_batch, rng): + def loss_fn(params): + logits_batch, new_model_state = workload.model_fn( + params, + input_batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss = workload.loss_fn(label_batch, logits_batch) + return loss, new_model_state + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (_, new_model_state), grad = grad_fn(current_param_container) + optimizer_state, opt_update_fn = optimizer_state + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_model_state, (new_optimizer_state, opt_update_fn), updated_params + def update_params( workload: spec.Workload, current_param_container: spec.ParameterContainer, @@ -43,8 +67,8 @@ def update_params( hyperparameters: spec.Hyperparamters, input_batch: spec.Tensor, label_batch: spec.Tensor, - # This will define the output activation via `output_activation_fn`. loss_type: spec.LossType, + # This will define the output activation via `output_activation_fn`. 
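+    # (e.g. a sigmoid over the per-task logits for sigmoid cross-entropy
+    # losses such as the multi-label ogbg-molpcba targets)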
optimizer_state: spec.OptimizerState, eval_results: List[Tuple[int, float]], global_step: int, @@ -54,26 +78,19 @@ def update_params( del loss_type del eval_results del global_step + print('\n'*10, current_param_container) - def loss_fn(params): - logits_batch, new_model_state = self.model_fn( - params, - input_batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss = workload.loss_fn(label_batch, logits_batch) - return loss, new_model_state - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (_, new_model_state), grad = grad_fn(current_param_container) optimizer_state, opt_update_fn = optimizer_state - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return (new_optimizer_state, opt_update_fn), updated_params, new_model_state + new_model_state, new_optimizer_state, new_params = pmapped_train_step( + workload, opt_update_fn, model_state, optimizer_state, + current_param_container, hyperparameters, input_batch, label_batch, rng) + + steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet') + if (global_step + 1) % steps_per_epoch == 0: + # sync batch statistics across replicas once per epoch + new_model_state = workload.sync_batch_stats(new_model_state) + return new_model_state, new_optimizer_state, new_params def data_selection( workload: spec.Workload, @@ -89,4 +106,4 @@ def data_selection( """ graphs = next(input_queue) labels = graphs.globals - return x, labels + return graphs, labels diff --git a/workloads/ogb/ogb_jax/tuning_search_space.json b/workloads/ogb/ogb_jax/tuning_search_space.json index 622722c6c..7aba31610 100644 --- a/workloads/ogb/ogb_jax/tuning_search_space.json +++ b/workloads/ogb/ogb_jax/tuning_search_space.json @@ -1,3 +1 @@ -{ - "learning_rate": {"feasible_points": [1e-3]}, -} +{"learning_rate": {"feasible_points": [1e-3]}} \ No newline at end of file diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index 147922d00..ed2196a29 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -54,7 +54,7 @@ def build_input_queue( split: str, data_dir: str, batch_size: int): - return iter(self._build_dataset(data_rng, split, data_dir, batch_size)) + return self._build_dataset(data_rng, split, data_dir, batch_size).as_numpy_iterator() @property def param_shapes(self): @@ -120,6 +120,10 @@ def output_activation_fn( loss_type: spec.LossType) -> spec.Tensor: pass + @property + def loss_type(self): + return spec.LossType.SOFTMAX_CROSS_ENTROPY + def _get_valid_mask(labels: jnp.ndarray, graphs: jraph.GraphsTuple) -> jnp.ndarray: """Gets the binary mask indicating only valid labels and graphs.""" @@ -151,13 +155,12 @@ def model_fn( graphs = self._replace_globals(input_batch) # Get predicted logits - variables = {'params': params, **model_state} + variables = {'params': params}#, **model_state} DO NOT SUBMIT train = mode == spec.ForwardPassMode.TRAIN pred_graphs = self._model.apply( variables, graphs, - train=train, - rngs=rng) + rngs={'dropout': rng}) logits = pred_graphs.globals # Get the mask for valid labels and graphs. 
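
To make the jitted training-step pattern above easier to follow outside of the diff, here is a minimal, self-contained sketch of the same idea on a toy linear model (illustrative names and shapes only, not taken from the patches): the optax update function is passed as a static argument, so jax.jit traces the step once and reuses it, while the params, optimizer state and batch remain traced arrays. The patches additionally mark the workload object itself as static, which works the same way as long as that object is hashable.

import functools

import jax
import jax.numpy as jnp
import optax


@functools.partial(jax.jit, static_argnums=(0,))
def train_step(opt_update_fn, params, opt_state, inputs, labels):
  # Toy stand-in for the GNN; the step structure (value_and_grad ->
  # optax update -> apply_updates) mirrors the submission above.
  def loss_fn(p):
    logits = inputs @ p['w'] + p['b']
    return jnp.mean(optax.sigmoid_binary_cross_entropy(logits, labels))

  loss, grads = jax.value_and_grad(loss_fn)(params)
  updates, new_opt_state = opt_update_fn(grads, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state, loss


params = {'w': jnp.zeros((16, 128)), 'b': jnp.zeros((128,))}
opt_init_fn, opt_update_fn = optax.adam(1e-3)
opt_state = opt_init_fn(params)
inputs = jnp.ones((4, 16))
labels = jnp.zeros((4, 128))
params, opt_state, loss = train_step(
    opt_update_fn, params, opt_state, inputs, labels)
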
From ddfede91273e6bfa33b4e19d06cb76f275a6b631 Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Wed, 3 Nov 2021 18:50:57 +0100 Subject: [PATCH 06/46] Fix multiple issues --- workloads/ogb/ogb_jax/submission.py | 15 +++++++-------- workloads/ogb/ogb_jax/workload.py | 24 ++++++++++++++---------- workloads/ogb/workload.py | 2 ++ 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index bb226c3fc..ee8627650 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -1,6 +1,7 @@ from typing import Iterator, List, Tuple import functools +import numpy as np import jax import jax.numpy as jnp import optax @@ -9,8 +10,8 @@ def get_batch_size(workload_name): - del workload_name - return 256 + batch_sizes = {'ogb_jax': 256} + return batch_sizes[workload_name] def optimizer(hyperparameters: spec.Hyperparamters) -> optax.GradientTransformation: @@ -53,11 +54,10 @@ def loss_fn(params): grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, new_model_state), grad = grad_fn(current_param_container) - optimizer_state, opt_update_fn = optimizer_state updates, new_optimizer_state = opt_update_fn( grad, optimizer_state, current_param_container) updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, (new_optimizer_state, opt_update_fn), updated_params + return new_model_state, new_optimizer_state, updated_params def update_params( workload: spec.Workload, @@ -77,7 +77,6 @@ def update_params( del current_params_types del loss_type del eval_results - del global_step print('\n'*10, current_param_container) optimizer_state, opt_update_fn = optimizer_state @@ -85,12 +84,12 @@ def update_params( workload, opt_update_fn, model_state, optimizer_state, current_param_container, hyperparameters, input_batch, label_batch, rng) - steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet') + steps_per_epoch = workload.num_train_examples // get_batch_size('ogb_jax') if (global_step + 1) % steps_per_epoch == 0: # sync batch statistics across replicas once per epoch new_model_state = workload.sync_batch_stats(new_model_state) - return new_model_state, new_optimizer_state, new_params + return (new_optimizer_state, opt_update_fn), new_params, new_model_state def data_selection( workload: spec.Workload, @@ -104,6 +103,6 @@ def data_selection( Each element of the queue is a single training example and label. Return a tuple of input label batches. 
""" - graphs = next(input_queue) + graphs = jax.tree_map(np.asarray, next(input_queue)) labels = graphs.globals return graphs, labels diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index ed2196a29..6ebdb9817 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -34,16 +34,17 @@ def __init__(self): def _normalize(self, image): pass - def _build_dataset(self, + def _build_dataset( + self, data_rng: jax.random.PRNGKey, split: str, data_dir: str, - batch_size): + batch_size: int): datasets = input_pipeline.get_datasets( - batch_size, - add_virtual_node=False, - add_undirected_edges=True, - add_self_loops=True) + batch_size, + add_virtual_node=False, + add_undirected_edges=True, + add_self_loops=True) if self._init_graphs is None: self._init_graphs = next(datasets['train'].as_numpy_iterator()) return datasets[split] @@ -107,6 +108,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: rng, init_rng = jax.random.split(rng) init_graphs = self._replace_globals(self._init_graphs) params = jax.jit(self._model.init)(init_rng, init_graphs) + self._model.deterministic = False self._param_shapes = jax.tree_map( lambda x: spec.ShapeTuple(x.shape), params) @@ -124,8 +126,10 @@ def output_activation_fn( def loss_type(self): return spec.LossType.SOFTMAX_CROSS_ENTROPY - def _get_valid_mask(labels: jnp.ndarray, - graphs: jraph.GraphsTuple) -> jnp.ndarray: + def _get_valid_mask( + self, + labels: jnp.ndarray, + graphs: jraph.GraphsTuple) -> jnp.ndarray: """Gets the binary mask indicating only valid labels and graphs.""" # We have to ignore all NaN values - which indicate labels for which # the current graphs have no label. @@ -158,7 +162,7 @@ def model_fn( variables = {'params': params}#, **model_state} DO NOT SUBMIT train = mode == spec.ForwardPassMode.TRAIN pred_graphs = self._model.apply( - variables, + variables['params'], graphs, rngs={'dropout': rng}) logits = pred_graphs.globals @@ -202,7 +206,7 @@ def loss_fn( ) loss = self._binary_cross_entropy_with_mask( logits=logits_batch, labels=label_batch, mask=self._mask) - mean_loss = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask) + mean_loss = jnp.sum(jnp.where(self._mask, loss, 0)) / jnp.sum(self._mask) return mean_loss def _compute_mean_average_precision(self, labels, logits): diff --git a/workloads/ogb/workload.py b/workloads/ogb/workload.py index 27396e821..1061c777d 100644 --- a/workloads/ogb/workload.py +++ b/workloads/ogb/workload.py @@ -53,6 +53,8 @@ def eval_model( self._eval_ds = self.build_input_queue( data_rng, 'test', data_dir, batch_size=eval_batch_size) + self._model.deterministic = True + total_metrics = { 'accuracy': 0., 'average_precision': 0., From 703639417c92796eb5e30b0dcb0849fd4bf0bc6f Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Wed, 3 Nov 2021 20:08:15 +0100 Subject: [PATCH 07/46] Add missing import --- workloads/ogb/ogb_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index 6ebdb9817..4414de100 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -1,6 +1,7 @@ """OGB workload implemented in Jax.""" from typing import Tuple +import numpy as np import sklearn.metrics import jax From 5a8565d71eae3a9b1019a5ac9bf89f6df82564bd Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Wed, 3 Nov 2021 20:50:19 +0100 Subject: [PATCH 08/46] Add time limits; minor fixes --- workloads/ogb/ogb_jax/submission.py | 9 ++++----- 
workloads/ogb/workload.py | 8 ++++---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index ee8627650..166b4a4b2 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -77,17 +77,16 @@ def update_params( del current_params_types del loss_type del eval_results - print('\n'*10, current_param_container) optimizer_state, opt_update_fn = optimizer_state new_model_state, new_optimizer_state, new_params = pmapped_train_step( workload, opt_update_fn, model_state, optimizer_state, current_param_container, hyperparameters, input_batch, label_batch, rng) - steps_per_epoch = workload.num_train_examples // get_batch_size('ogb_jax') - if (global_step + 1) % steps_per_epoch == 0: - # sync batch statistics across replicas once per epoch - new_model_state = workload.sync_batch_stats(new_model_state) + #steps_per_epoch = workload.num_train_examples // get_batch_size('ogb_jax') + #if (global_step + 1) % steps_per_epoch == 0: + # # sync batch statistics across replicas once per epoch + # new_model_state = workload.sync_batch_stats(new_model_state) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/workloads/ogb/workload.py b/workloads/ogb/workload.py index 1061c777d..da075bbf3 100644 --- a/workloads/ogb/workload.py +++ b/workloads/ogb/workload.py @@ -9,7 +9,7 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - return 0.24 + return 0.255 @property def loss_type(self): @@ -33,11 +33,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self): - raise NotImplementedError + return 12000 # 3h20m @property def eval_period_time_sec(self): - raise NotImplementedError + return 360 # 60 minutes (too long) def eval_model( self, @@ -61,7 +61,7 @@ def eval_model( 'loss': 0., } # Loop over graphs. - for graphs in self._eval_ds.as_numpy_iterator(): + for graphs in self._eval_ds: logits, _ = self.model_fn( params, graphs, From a05198f235fa3c8efe4f7c7107718b1909ff893f Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Thu, 4 Nov 2021 01:31:49 +0100 Subject: [PATCH 09/46] Change validation set; allow non-deterministic training --- workloads/ogb/ogb_jax/submission.py | 1 + workloads/ogb/workload.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index 166b4a4b2..a903a3a08 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -78,6 +78,7 @@ def update_params( del loss_type del eval_results + workload._model.deterministic = False optimizer_state, opt_update_fn = optimizer_state new_model_state, new_optimizer_state, new_params = pmapped_train_step( workload, opt_update_fn, model_state, optimizer_state, diff --git a/workloads/ogb/workload.py b/workloads/ogb/workload.py index da075bbf3..605b41799 100644 --- a/workloads/ogb/workload.py +++ b/workloads/ogb/workload.py @@ -50,8 +50,8 @@ def eval_model( eval_batch_size = 256 num_batches = self.num_eval_examples // eval_batch_size if self._eval_ds is None: - self._eval_ds = self.build_input_queue( - data_rng, 'test', data_dir, batch_size=eval_batch_size) + self._eval_ds = self._build_dataset( + data_rng, 'validation', data_dir, batch_size=eval_batch_size) self._model.deterministic = True @@ -61,7 +61,7 @@ def eval_model( 'loss': 0., } # Loop over graphs. 
- for graphs in self._eval_ds: + for graphs in self._eval_ds.as_numpy_iterator(): logits, _ = self.model_fn( params, graphs, From 436bcff1767bbc49f41b48adb8527f66bce77438 Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Sun, 7 Nov 2021 21:05:48 +0100 Subject: [PATCH 10/46] Fix all eval metrics --- workloads/ogb/ogb_jax/metrics.py | 58 +++++++++++++++++++++++++++++ workloads/ogb/ogb_jax/submission.py | 3 +- workloads/ogb/ogb_jax/workload.py | 39 ++----------------- workloads/ogb/workload.py | 19 ++++------ 4 files changed, 71 insertions(+), 48 deletions(-) create mode 100644 workloads/ogb/ogb_jax/metrics.py diff --git a/workloads/ogb/ogb_jax/metrics.py b/workloads/ogb/ogb_jax/metrics.py new file mode 100644 index 000000000..190e22998 --- /dev/null +++ b/workloads/ogb/ogb_jax/metrics.py @@ -0,0 +1,58 @@ +# Forked from Flax example which can be found here: +# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py + +import numpy as np +import jax +import jax.numpy as jnp +import flax +from clu import metrics +from sklearn.metrics import average_precision_score + + +def predictions_match_labels(*, logits: jnp.ndarray, labels: jnp.ndarray, + **kwargs) -> jnp.ndarray: + """Returns a binary array indicating where predictions match the labels.""" + del kwargs # Unused. + preds = (logits > 0) + return (preds == labels).astype(jnp.float32) + + +@flax.struct.dataclass +class MeanAveragePrecision( + metrics.CollectingMetric.from_outputs(('labels', 'logits', 'mask'))): + """Computes the mean average precision (mAP) over different tasks.""" + + def compute(self): + # Matches the official OGB evaluation scheme for mean average precision. + labels = self.values['labels'] + logits = self.values['logits'] + mask = self.values['mask'] + + assert logits.shape == labels.shape == mask.shape + assert len(logits.shape) == 2 + + probs = jax.nn.sigmoid(logits) + num_tasks = labels.shape[1] + average_precisions = np.full(num_tasks, np.nan) + + for task in range(num_tasks): + # AP is only defined when there is at least one negative data + # and at least one positive data. + if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: + is_labeled = mask[:, task] + average_precisions[task] = average_precision_score( + labels[is_labeled, task], probs[is_labeled, task]) + + # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. 
+ if np.isnan(average_precisions).all(): + return np.nan + return np.nanmean(average_precisions) + + +@flax.struct.dataclass +class EvalMetrics(metrics.Collection): + + accuracy: metrics.Average.from_fun(predictions_match_labels) + loss: metrics.Average.from_output('loss') + mean_average_precision: MeanAveragePrecision + diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index a903a3a08..26a6af1ff 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -50,7 +50,8 @@ def loss_fn(params): rng, update_batch_norm=True) loss = workload.loss_fn(label_batch, logits_batch) - return loss, new_model_state + mean_loss = jnp.sum(jnp.where(workload._mask, loss, 0)) / jnp.sum(workload._mask) + return mean_loss, new_model_state grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, new_model_state), grad = grad_fn(current_param_container) diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index 4414de100..b5660894d 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -13,6 +13,7 @@ from workloads.ogb.workload import OGB from workloads.ogb.ogb_jax import input_pipeline from workloads.ogb.ogb_jax import models +from workloads.ogb.ogb_jax import metrics class OGBWorkload(OGB): @@ -207,41 +208,9 @@ def loss_fn( ) loss = self._binary_cross_entropy_with_mask( logits=logits_batch, labels=label_batch, mask=self._mask) - mean_loss = jnp.sum(jnp.where(self._mask, loss, 0)) / jnp.sum(self._mask) - return mean_loss - - def _compute_mean_average_precision(self, labels, logits): - """Computes the mean average precision (mAP) over different tasks.""" - # Matches the official OGB evaluation scheme for mean average precision. - assert logits.shape == labels.shape == self._mask.shape - assert len(logits.shape) == 2 - - probs = jax.nn.sigmoid(logits) - num_tasks = labels.shape[1] - average_precisions = np.full(num_tasks, np.nan) - - for task in range(num_tasks): - # AP is only defined when there is at least one negative data - # and at least one positive data. - if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: - is_labeled = self._mask[:, task] - average_precisions[task] = sklearn.metrics.average_precision_score( - labels[is_labeled, task], probs[is_labeled, task]) - - # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. 
- if np.isnan(average_precisions).all(): - return np.nan - return np.nanmean(average_precisions) + return loss def _eval_metric(self, labels, logits): - """Return the accuracy, average precision, and loss as a dict.""" - preds = (logits > 0) - accuracy = np.nanmean((preds == labels).astype(jnp.float32)) - average_precision = self._compute_mean_average_precision(labels, logits) loss = self.loss_fn(labels, logits) - metrics = { - 'accuracy': accuracy, - 'average_precision': average_precision, - 'loss': loss, - } - return metrics + return metrics.EvalMetrics.single_from_model_output( + loss=loss, logits=logits, labels=labels, mask=self._mask) diff --git a/workloads/ogb/workload.py b/workloads/ogb/workload.py index 605b41799..e517cf358 100644 --- a/workloads/ogb/workload.py +++ b/workloads/ogb/workload.py @@ -5,11 +5,11 @@ class OGB(spec.Workload): def has_reached_goal(self, eval_result: float) -> bool: - return eval_result['average_precision'] > self.target_value + return eval_result['mean_average_precision'] > self.target_value @property def target_value(self): - return 0.255 + return 0.25 @property def loss_type(self): @@ -48,18 +48,13 @@ def eval_model( """Run a full evaluation of the model.""" data_rng, model_rng = prng.split(rng, 2) eval_batch_size = 256 - num_batches = self.num_eval_examples // eval_batch_size if self._eval_ds is None: self._eval_ds = self._build_dataset( data_rng, 'validation', data_dir, batch_size=eval_batch_size) self._model.deterministic = True - total_metrics = { - 'accuracy': 0., - 'average_precision': 0., - 'loss': 0., - } + total_metrics = None # Loop over graphs. for graphs in self._eval_ds.as_numpy_iterator(): logits, _ = self.model_fn( @@ -71,7 +66,7 @@ def eval_model( update_batch_norm=False) labels = graphs.globals batch_metrics = self._eval_metric(labels, logits) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } - return {k: float(v / num_batches) for k, v in total_metrics.items()} + total_metrics = (batch_metrics if total_metrics is None + else total_metrics.merge(batch_metrics)) + return {k: float(v) for k, v in total_metrics.compute().items()} + From 7961ee670a2640c5a114d5fa102a033abef005b1 Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 7 Nov 2021 21:55:34 +0100 Subject: [PATCH 11/46] Move eval_model to jax-specific workload --- workloads/ogb/ogb_jax/workload.py | 33 +++++++++++++++++++++++++++++- workloads/ogb/workload.py | 34 +------------------------------ 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index b5660894d..826a6782b 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -129,7 +129,7 @@ def loss_type(self): return spec.LossType.SOFTMAX_CROSS_ENTROPY def _get_valid_mask( - self, + self, labels: jnp.ndarray, graphs: jraph.GraphsTuple) -> jnp.ndarray: """Gets the binary mask indicating only valid labels and graphs.""" @@ -214,3 +214,34 @@ def _eval_metric(self, labels, logits): loss = self.loss_fn(labels, logits) return metrics.EvalMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=self._mask) + + def eval_model( + self, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str): + """Run a full evaluation of the model.""" + data_rng, model_rng = prng.split(rng, 2) + eval_batch_size = 256 + if self._eval_ds is None: + self._eval_ds = self._build_dataset( + data_rng, 'validation', data_dir, 
batch_size=eval_batch_size) + + self._model.deterministic = True + + total_metrics = None + # Loop over graphs. + for graphs in self._eval_ds.as_numpy_iterator(): + logits, _ = self.model_fn( + params, + graphs, + model_state, + spec.ForwardPassMode.EVAL, + model_rng, + update_batch_norm=False) + labels = graphs.globals + batch_metrics = self._eval_metric(labels, logits) + total_metrics = (batch_metrics if total_metrics is None + else total_metrics.merge(batch_metrics)) + return {k: float(v) for k, v in total_metrics.compute().items()} diff --git a/workloads/ogb/workload.py b/workloads/ogb/workload.py index e517cf358..a33a674d3 100644 --- a/workloads/ogb/workload.py +++ b/workloads/ogb/workload.py @@ -37,36 +37,4 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 360 # 60 minutes (too long) - - def eval_model( - self, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str): - """Run a full evaluation of the model.""" - data_rng, model_rng = prng.split(rng, 2) - eval_batch_size = 256 - if self._eval_ds is None: - self._eval_ds = self._build_dataset( - data_rng, 'validation', data_dir, batch_size=eval_batch_size) - - self._model.deterministic = True - - total_metrics = None - # Loop over graphs. - for graphs in self._eval_ds.as_numpy_iterator(): - logits, _ = self.model_fn( - params, - graphs, - model_state, - spec.ForwardPassMode.EVAL, - model_rng, - update_batch_norm=False) - labels = graphs.globals - batch_metrics = self._eval_metric(labels, logits) - total_metrics = (batch_metrics if total_metrics is None - else total_metrics.merge(batch_metrics)) - return {k: float(v) for k, v in total_metrics.compute().items()} - + return 360 # 60 minutes (too long) From a598b39a11bbad0bab62669c5911d3cbc57509b3 Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Tue, 9 Nov 2021 18:51:14 +0100 Subject: [PATCH 12/46] Fix imports --- workloads/ogb/ogb_jax/workload.py | 1 + workloads/ogb/workload.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index 826a6782b..088c55061 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +import random_utils as prng import jraph from flax import linen as nn diff --git a/workloads/ogb/workload.py b/workloads/ogb/workload.py index a33a674d3..491ebcf86 100644 --- a/workloads/ogb/workload.py +++ b/workloads/ogb/workload.py @@ -1,4 +1,3 @@ -import random_utils as prng import spec From 3a1958b215ae43c2b3b015f5a43643a25c173c48 Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen Date: Fri, 12 Nov 2021 00:39:09 +0100 Subject: [PATCH 13/46] Pmap train step (sth does not work) --- workloads/ogb/ogb_jax/submission.py | 28 ++++++++++++++--- workloads/ogb/ogb_jax/workload.py | 47 ++++++++++++++++++++++------- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index 26a6af1ff..63d37c043 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -4,6 +4,9 @@ import numpy as np import jax import jax.numpy as jnp +from jax import lax +from flax import jax_utils +import jraph import optax import spec @@ -31,14 +34,16 @@ def init_optimizer_state( lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optimizer(hyperparameters) init_optimizer_state = 
opt_init_fn(params_zeros_like) - return init_optimizer_state, opt_update_fn + return jax_utils.replicate(init_optimizer_state), opt_update_fn # We need to jax.pmap here instead of inside update_params because the latter # would recompile the function every step. @functools.partial( - jax.jit, - static_argnums=(0, 1)) + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0, None), + static_broadcasted_argnums=(0, 1)) def pmapped_train_step(workload, opt_update_fn, model_state, optimizer_state, current_param_container, hyperparameters, input_batch, label_batch, rng): def loss_fn(params): @@ -55,11 +60,13 @@ def loss_fn(params): grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, new_model_state), grad = grad_fn(current_param_container) + grad = lax.pmean(grad, axis_name='batch') updates, new_optimizer_state = opt_update_fn( grad, optimizer_state, current_param_container) updated_params = optax.apply_updates(current_param_container, updates) return new_model_state, new_optimizer_state, updated_params + def update_params( workload: spec.Workload, current_param_container: spec.ParameterContainer, @@ -90,8 +97,13 @@ def update_params( # # sync batch statistics across replicas once per epoch # new_model_state = workload.sync_batch_stats(new_model_state) + #return ( + # (jax_utils.unreplicate(new_optimizer_state), opt_update_fn), + # jax_utils.unreplicate(new_params), + # jax_utils.replicate(new_model_state)) return (new_optimizer_state, opt_update_fn), new_params, new_model_state + def data_selection( workload: spec.Workload, input_queue: Iterator[Tuple[spec.Tensor, spec.Tensor]], @@ -104,6 +116,12 @@ def data_selection( Each element of the queue is a single training example and label. Return a tuple of input label batches. """ - graphs = jax.tree_map(np.asarray, next(input_queue)) - labels = graphs.globals + graphs = [] + labels = [] + for _ in range(jax.local_device_count()): + graph = jax.tree_map(np.asarray, next(input_queue)) + graphs.append(graph) + labels.append(graph.globals) + graphs = jax.tree_multimap(lambda *x: jnp.stack(x, axis=0), *graphs) + labels = jnp.stack(labels) return graphs, labels diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index 088c55061..051293166 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -1,6 +1,7 @@ """OGB workload implemented in Jax.""" from typing import Tuple +import functools import numpy as np import sklearn.metrics @@ -9,6 +10,7 @@ import random_utils as prng import jraph from flax import linen as nn +from flax import jax_utils import spec from workloads.ogb.workload import OGB @@ -115,7 +117,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._param_shapes = jax.tree_map( lambda x: spec.ShapeTuple(x.shape), params) - return params, None + return jax_utils.replicate(params), None # Keep this separate from the loss function in order to support optimizers # that use the logits. 
@@ -211,11 +213,39 @@ def loss_fn( logits=logits_batch, labels=label_batch, mask=self._mask) return loss + def _shard_graphs(self, graphs): + graph_list = jraph.unbatch(graphs) + n_devices = jax.local_device_count() + local_batch_size = len(graph_list) // n_devices + graphs = [ + jax.tree_map( + np.asarray, + jraph.batch(graph_list[i*local_batch_size:(i+1)*local_batch_size])) + for i in range(n_devices)] + graphs = jax.tree_multimap(lambda *x: jnp.stack(x, axis=0), *graphs) + return graphs + def _eval_metric(self, labels, logits): loss = self.loss_fn(labels, logits) return metrics.EvalMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=self._mask) + #@functools.partial( + # jax.pmap, + # axis_name='batch', + # in_axes=(None, 0, 0, 0, None), + # static_broadcasted_argnums=(0,)) + def _eval_batch(self, params, graphs, model_state, rng): + labels = graphs.globals + logits, _ = self.model_fn( + params, + graphs, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + return self._eval_metric(labels, logits) + def eval_model( self, params: spec.ParameterContainer, @@ -230,19 +260,14 @@ def eval_model( data_rng, 'validation', data_dir, batch_size=eval_batch_size) self._model.deterministic = True + params = jax_utils.unreplicate(params) + model_state = jax_utils.unreplicate(model_state) total_metrics = None - # Loop over graphs. + # Loop over graph batches in eval dataset for graphs in self._eval_ds.as_numpy_iterator(): - logits, _ = self.model_fn( - params, - graphs, - model_state, - spec.ForwardPassMode.EVAL, - model_rng, - update_batch_norm=False) - labels = graphs.globals - batch_metrics = self._eval_metric(labels, logits) + # graphs = self._shard_graphs(graphs) + batch_metrics = self._eval_batch(params, graphs, model_state, model_rng) total_metrics = (batch_metrics if total_metrics is None else total_metrics.merge(batch_metrics)) return {k: float(v) for k, v in total_metrics.compute().items()} From 9dbe6d5c497f0c80c6e990e75658cd574edc631c Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Thu, 2 Dec 2021 21:47:50 -0500 Subject: [PATCH 14/46] refactoring gnn input pipeline to do all masking and replicating inside input_pipeline.py. various other refactors. adding an optional mask arg to the loss_fn API. --- setup.py | 7 +- spec.py | 5 +- submission_runner.py | 3 +- workloads/ogb/ogb_jax/input_pipeline.py | 72 ++++++++--- workloads/ogb/ogb_jax/metrics.py | 1 - workloads/ogb/ogb_jax/models.py | 14 +-- workloads/ogb/ogb_jax/submission.py | 41 +++--- workloads/ogb/ogb_jax/workload.py | 161 ++++++++++-------------- workloads/wmt/wmt_jax/workload.py | 5 +- 9 files changed, 157 insertions(+), 152 deletions(-) diff --git a/setup.py b/setup.py index 8ba06b803..27466dc1a 100644 --- a/setup.py +++ b/setup.py @@ -12,10 +12,11 @@ 'flax==0.3.5', 'optax==0.0.9', 'tensorflow_datasets==4.4.0', - 'tensorflow-cpu==2.5.0', + 'tensorflow==2.5.0', ] +# Assumes CUDA 11.x and a compatible NVIDIA driver and CuDNN. 
setup( name='algorithmic_efficiency', version='0.0.1', @@ -28,7 +29,11 @@ packages=find_packages(), install_requires=[ 'absl-py==0.14.0', + 'clu==0.0.6', + 'jraph==0.0.2.dev', 'numpy>=1.19.2', + 'pandas==1.3.4', + 'scikit-learn==1.0.1', ], extras_require={ 'jax-cpu': jax_core_deps + ['jaxlib==0.1.71'], diff --git a/spec.py b/spec.py index e5035087a..e1404b8cc 100644 --- a/spec.py +++ b/spec.py @@ -2,7 +2,7 @@ import enum import time -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import abc @@ -199,7 +199,8 @@ def output_activation_fn( def loss_fn( self, label_batch: Tensor, # Dense (not one-hot) labels. - logits_batch: Tensor) -> Tensor: # differentiable + logits_batch: Tensor, + mask_batch: Optional[Tensor]) -> Tensor: # differentiable """return oned_array_of_losses_per_example""" @abc.abstractmethod diff --git a/submission_runner.py b/submission_runner.py index 47886525a..3f7b0e4d1 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -175,7 +175,7 @@ def train_once( data_select_rng, update_rng, eval_rng = prng.split( step_rng, 3) start_time = time.time() - selected_train_input_batch, selected_train_label_batch = data_selection( + selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( workload, input_queue, optimizer_state, @@ -192,6 +192,7 @@ def train_once( hyperparameters=hyperparameters, input_batch=selected_train_input_batch, label_batch=selected_train_label_batch, + mask_batch=selected_train_mask_batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, eval_results=eval_results, diff --git a/workloads/ogb/ogb_jax/input_pipeline.py b/workloads/ogb/ogb_jax/input_pipeline.py index 8c44dda2b..aae4cef00 100644 --- a/workloads/ogb/ogb_jax/input_pipeline.py +++ b/workloads/ogb/ogb_jax/input_pipeline.py @@ -5,6 +5,7 @@ import functools from typing import Dict, NamedTuple +import jax import jraph import numpy as np import tensorflow as tf @@ -32,10 +33,50 @@ def get_raw_datasets() -> Dict[str, tf.data.Dataset]: return datasets -def get_datasets(batch_size: int, - add_virtual_node: bool = True, - add_undirected_edges: bool = True, - add_self_loops: bool = True) -> Dict[str, tf.data.Dataset]: +def _get_valid_mask(graphs: jraph.GraphsTuple): + """Gets the binary mask indicating only valid labels and graphs.""" + labels = graphs.globals + # We have to ignore all NaN values - which indicate labels for which + # the current graphs have no label. + labels_masks = ~np.isnan(labels) + + # Since we have extra 'dummy' graphs in our batch due to padding, we want + # to mask out any loss associated with the dummy graphs. + # Since we padded with `pad_with_graphs` we can recover the mask by using + # get_graph_padding_mask. + graph_masks = jraph.get_graph_padding_mask(graphs) + + # Combine the mask over labels with the mask over graphs. 
+ masks = labels_masks & graph_masks[:, None] + graphs = graphs._replace(globals=[]) + return graphs, labels, masks + + +def _batch_for_pmap(iterator): + graphs = [] + labels = [] + masks = [] + count = 0 + for graph_batch, label_batch, mask_batch in iterator: + count += 1 + graphs.append(graph_batch) + labels.append(label_batch) + masks.append(mask_batch) + if count == jax.local_device_count(): + graphs = jax.tree_multimap(lambda *x: np.stack(x, axis=0), *graphs) + labels = np.stack(labels) + masks = np.stack(masks) + yield graphs, labels, masks + graphs = [] + labels = [] + masks = [] + count = 0 + + +def get_dataset_iters(batch_size: int, + add_virtual_node: bool = True, + add_undirected_edges: bool = True, + add_self_loops: bool = True) -> Dict[str, tf.data.Dataset]: """Returns datasets of batched GraphsTuples, organized by split.""" if batch_size <= 1: raise ValueError('Batch size must be > 1 to account for padding graphs.') @@ -53,7 +94,6 @@ def get_datasets(batch_size: int, # Process each split separately. for split_name in datasets: - # Convert to GraphsTuple. datasets[split_name] = datasets[split_name].map( convert_to_graphs_tuple_fn, @@ -77,23 +117,25 @@ def get_datasets(batch_size: int, if split_name == 'train': dataset_split = dataset_split.shuffle(100, reshuffle_each_iteration=True) dataset_split = dataset_split.repeat() + # We cache the validation and test sets, since these are small. + else: + dataset_split = dataset_split.cache() - # Batch and pad each split. - batching_fn = functools.partial( - jraph.dynamically_batch, + # Batch and pad each split. Note that this also converts the graphs to + # numpy. + batched_iter = jraph.dynamically_batch( graphs_tuple_iterator=iter(dataset_split), n_node=budget.n_node, n_edge=budget.n_edge, n_graph=budget.n_graph) - dataset_split = tf.data.Dataset.from_generator( - batching_fn, - output_signature=padded_graphs_spec) - # We cache the validation and test sets, since these are small. - if split_name in ['validation', 'test']: - dataset_split = dataset_split.cache() + # An iterator of Tuple[graph, labels, mask]. + masked_iter = map(_get_valid_mask, batched_iter) - datasets[split_name] = dataset_split + # An iterator the same as above, but where each element has an extra leading + # dim of size jax.local_device_count(). + pmapped_iterator = _batch_for_pmap(masked_iter) + datasets[split_name] = pmapped_iterator return datasets diff --git a/workloads/ogb/ogb_jax/metrics.py b/workloads/ogb/ogb_jax/metrics.py index 190e22998..7410500b8 100644 --- a/workloads/ogb/ogb_jax/metrics.py +++ b/workloads/ogb/ogb_jax/metrics.py @@ -51,7 +51,6 @@ def compute(self): @flax.struct.dataclass class EvalMetrics(metrics.Collection): - accuracy: metrics.Average.from_fun(predictions_match_labels) loss: metrics.Average.from_output('loss') mean_average_precision: MeanAveragePrecision diff --git a/workloads/ogb/ogb_jax/models.py b/workloads/ogb/ogb_jax/models.py index ce5d19d08..80ba51bde 100644 --- a/workloads/ogb/ogb_jax/models.py +++ b/workloads/ogb/ogb_jax/models.py @@ -49,10 +49,9 @@ class GraphNet(nn.Module): skip_connections: bool = True use_edge_model: bool = True layer_norm: bool = True - deterministic: bool = True @nn.compact - def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: + def __call__(self, graphs: jraph.GraphsTuple, train: bool) -> jraph.GraphsTuple: # We will first linearly project the original features as 'embeddings'. 
embedder = jraph.GraphMapFeatures( embed_node_fn=nn.Dense(self.latent_size), @@ -67,18 +66,18 @@ def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: update_edge_fn = jraph.concatenated_args( MLP(mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=not train)) else: update_edge_fn = None update_node_fn = jraph.concatenated_args( MLP(mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=not train)) update_global_fn = jraph.concatenated_args( MLP(mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=not train)) graph_net = jraph.GraphNetwork( update_node_fn=update_node_fn, @@ -117,7 +116,6 @@ class GraphConvNet(nn.Module): dropout_rate: float = 0 skip_connections: bool = True layer_norm: bool = True - deterministic: bool = True pooling_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray] = jraph.segment_mean @@ -140,7 +138,7 @@ def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: return graphs._replace(globals=pooled) @nn.compact - def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: + def __call__(self, graphs: jraph.GraphsTuple, train: bool) -> jraph.GraphsTuple: # We will first linearly project the original node features as 'embeddings'. embedder = jraph.GraphMapFeatures( embed_node_fn=nn.Dense(self.latent_size)) @@ -152,7 +150,7 @@ def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: update_node_fn = jraph.concatenated_args( MLP(mlp_feature_sizes, dropout_rate=self.dropout_rate, - deterministic=self.deterministic)) + deterministic=not train)) graph_conv = jraph.GraphConvolution( update_node_fn=update_node_fn, add_self_edges=True) diff --git a/workloads/ogb/ogb_jax/submission.py b/workloads/ogb/ogb_jax/submission.py index 63d37c043..62caabf93 100644 --- a/workloads/ogb/ogb_jax/submission.py +++ b/workloads/ogb/ogb_jax/submission.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import functools import numpy as np @@ -42,10 +42,13 @@ def init_optimizer_state( @functools.partial( jax.pmap, axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, None), + in_axes=(None, None, 0, 0, 0, None, 0, 0, 0, None), static_broadcasted_argnums=(0, 1)) def pmapped_train_step(workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, input_batch, label_batch, rng): + current_param_container, hyperparameters, input_batch, + label_batch, mask_batch, rng): + del hyperparameters + def loss_fn(params): logits_batch, new_model_state = workload.model_fn( params, @@ -54,8 +57,8 @@ def loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True) - loss = workload.loss_fn(label_batch, logits_batch) - mean_loss = jnp.sum(jnp.where(workload._mask, loss, 0)) / jnp.sum(workload._mask) + loss = workload.loss_fn(label_batch, logits_batch, mask_batch) + mean_loss = jnp.sum(jnp.where(mask_batch, loss, 0)) / jnp.sum(mask_batch) return mean_loss, new_model_state grad_fn = jax.value_and_grad(loss_fn, has_aux=True) @@ -75,6 +78,7 @@ def update_params( hyperparameters: spec.Hyperparamters, input_batch: spec.Tensor, label_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor], loss_type: spec.LossType, # This will define the output activation via `output_activation_fn`. 
optimizer_state: spec.OptimizerState, @@ -85,22 +89,18 @@ def update_params( del current_params_types del loss_type del eval_results + del global_step - workload._model.deterministic = False optimizer_state, opt_update_fn = optimizer_state new_model_state, new_optimizer_state, new_params = pmapped_train_step( workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, input_batch, label_batch, rng) + current_param_container, hyperparameters, input_batch, label_batch, + mask_batch, rng) #steps_per_epoch = workload.num_train_examples // get_batch_size('ogb_jax') #if (global_step + 1) % steps_per_epoch == 0: # # sync batch statistics across replicas once per epoch # new_model_state = workload.sync_batch_stats(new_model_state) - - #return ( - # (jax_utils.unreplicate(new_optimizer_state), opt_update_fn), - # jax_utils.unreplicate(new_params), - # jax_utils.replicate(new_model_state)) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -111,17 +111,6 @@ def data_selection( current_param_container: spec.ParameterContainer, hyperparameters: spec.Hyperparamters, global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a single training example and label. - Return a tuple of input label batches. - """ - graphs = [] - labels = [] - for _ in range(jax.local_device_count()): - graph = jax.tree_map(np.asarray, next(input_queue)) - graphs.append(graph) - labels.append(graph.globals) - graphs = jax.tree_multimap(lambda *x: jnp.stack(x, axis=0), *graphs) - labels = jnp.stack(labels) - return graphs, labels + rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue.""" + return next(input_queue) diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index 051293166..dbe39c2cd 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -1,6 +1,6 @@ """OGB workload implemented in Jax.""" -from typing import Tuple +from typing import Optional, Tuple import functools import numpy as np import sklearn.metrics @@ -22,10 +22,9 @@ class OGBWorkload(OGB): def __init__(self): - self._eval_ds = None + self._eval_iterator = None self._param_shapes = None self._init_graphs = None - self._mask = None self._model = models.GraphConvNet( latent_size=256, num_mlp_layers=2, @@ -33,26 +32,24 @@ def __init__(self): output_globals_size=128, dropout_rate=0.1, skip_connections=True, - layer_norm=True, - deterministic=True) + layer_norm=True) - def _normalize(self, image): - pass - - def _build_dataset( + def _build_iterator( self, data_rng: jax.random.PRNGKey, split: str, data_dir: str, batch_size: int): - datasets = input_pipeline.get_datasets( + dataset_iters = input_pipeline.get_dataset_iters( batch_size, add_virtual_node=False, add_undirected_edges=True, add_self_loops=True) if self._init_graphs is None: - self._init_graphs = next(datasets['train'].as_numpy_iterator()) - return datasets[split] + init_graphs = next(dataset_iters['train'])[0] + # Unreplicate the iterator that has the leading dim for pmapping. 
+ self._init_graphs = jax.tree_map(lambda x: x[0], init_graphs) + return dataset_iters[split] def build_input_queue( self, @@ -60,7 +57,7 @@ def build_input_queue( split: str, data_dir: str, batch_size: int): - return self._build_dataset(data_rng, split, data_dir, batch_size).as_numpy_iterator() + return self._build_iterator(data_rng, split, data_dir, batch_size) @property def param_shapes(self): @@ -100,20 +97,16 @@ def preprocess_for_eval( del train_stddev return raw_input_batch, raw_label_batch - def _replace_globals(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: - """Replaces the globals attribute with a constant feature for each graph.""" - return graphs._replace(globals=jnp.ones([graphs.n_node.shape[0], 1])) - def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: if self._init_graphs is None: raise ValueError( 'This should not happen, workload.build_input_queue() should be ' 'called before workload.init_model_fn()!' ) - rng, init_rng = jax.random.split(rng) - init_graphs = self._replace_globals(self._init_graphs) - params = jax.jit(self._model.init)(init_rng, init_graphs) - self._model.deterministic = False + rng, params_rng, dropout_rng = jax.random.split(rng, 3) + params = jax.jit(functools.partial(self._model.init, train=False))( + {'params': params_rng, 'dropout': dropout_rng}, self._init_graphs) + params = params['params'] self._param_shapes = jax.tree_map( lambda x: spec.ShapeTuple(x.shape), params) @@ -131,24 +124,6 @@ def output_activation_fn( def loss_type(self): return spec.LossType.SOFTMAX_CROSS_ENTROPY - def _get_valid_mask( - self, - labels: jnp.ndarray, - graphs: jraph.GraphsTuple) -> jnp.ndarray: - """Gets the binary mask indicating only valid labels and graphs.""" - # We have to ignore all NaN values - which indicate labels for which - # the current graphs have no label. - labels_mask = ~jnp.isnan(labels) - - # Since we have extra 'dummy' graphs in our batch due to padding, we want - # to mask out any loss associated with the dummy graphs. - # Since we padded with `pad_with_graphs` we can recover the mask by using - # get_graph_padding_mask. - graph_mask = jraph.get_graph_padding_mask(graphs) - - # Combine the mask over labels with the mask over graphs. - return labels_mask & graph_mask[:, None] - def model_fn( self, params: spec.ParameterContainer, @@ -158,29 +133,20 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" - # Extract labels. - labels = input_batch.globals - # Replace the global feature for graph classification. - graphs = self._replace_globals(input_batch) - - # Get predicted logits - variables = {'params': params}#, **model_state} DO NOT SUBMIT + assert model_state is None train = mode == spec.ForwardPassMode.TRAIN pred_graphs = self._model.apply( - variables['params'], - graphs, - rngs={'dropout': rng}) + {'params': params}, + input_batch, + rngs={'dropout': rng}, + train=train) logits = pred_graphs.globals - - # Get the mask for valid labels and graphs. 
- self._mask = self._get_valid_mask(labels, graphs) - return logits, None def _binary_cross_entropy_with_mask( self, - logits: jnp.ndarray, labels: jnp.ndarray, + logits: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: """Binary cross entropy loss for logits, with masked elements.""" assert logits.shape == labels.shape == mask.shape @@ -203,40 +169,47 @@ def _binary_cross_entropy_with_mask( def loss_fn( self, label_batch: spec.Tensor, - logits_batch: spec.Tensor) -> spec.Tensor: # differentiable - if self._mask is None: - raise ValueError( - 'This should not happen, workload.model_fn() should be ' - 'called before workload.loss_fn()!' - ) + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor]) -> spec.Tensor: # differentiable loss = self._binary_cross_entropy_with_mask( - logits=logits_batch, labels=label_batch, mask=self._mask) - return loss - - def _shard_graphs(self, graphs): - graph_list = jraph.unbatch(graphs) - n_devices = jax.local_device_count() - local_batch_size = len(graph_list) // n_devices - graphs = [ - jax.tree_map( - np.asarray, - jraph.batch(graph_list[i*local_batch_size:(i+1)*local_batch_size])) - for i in range(n_devices)] - graphs = jax.tree_multimap(lambda *x: jnp.stack(x, axis=0), *graphs) - return graphs + labels=label_batch, logits=logits_batch, mask=mask_batch) + mean_loss = jnp.sum(jnp.where(mask_batch, loss, 0)) / jnp.sum(mask_batch) + return mean_loss + + def _compute_mean_average_precision(self, labels, logits, masks): + """Computes the mean average precision (mAP) over different tasks.""" + # Matches the official OGB evaluation scheme for mean average precision. + assert logits.shape == labels.shape == masks.shape + assert len(logits.shape) == 2 - def _eval_metric(self, labels, logits): - loss = self.loss_fn(labels, logits) + probs = jax.nn.sigmoid(logits) + num_tasks = labels.shape[1] + average_precisions = np.full(num_tasks, np.nan) + + for task in range(num_tasks): + # AP is only defined when there is at least one negative data + # and at least one positive data. + if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: + is_labeled = masks[:, task] + average_precisions[task] = sklearn.metrics.average_precision_score( + labels[is_labeled, task], probs[is_labeled, task]) + + # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. 
+ if np.isnan(average_precisions).all(): + return np.nan + return np.nanmean(average_precisions) + + def _eval_metric(self, labels, logits, masks): + loss = self.loss_fn(labels, logits, masks) return metrics.EvalMetrics.single_from_model_output( - loss=loss, logits=logits, labels=labels, mask=self._mask) - - #@functools.partial( - # jax.pmap, - # axis_name='batch', - # in_axes=(None, 0, 0, 0, None), - # static_broadcasted_argnums=(0,)) - def _eval_batch(self, params, graphs, model_state, rng): - labels = graphs.globals + loss=loss, logits=logits, labels=labels, mask=masks) + + @functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, 0, 0, None), + static_broadcasted_argnums=(0,)) + def _eval_batch(self, params, graphs, labels, masks, model_state, rng): logits, _ = self.model_fn( params, graphs, @@ -244,7 +217,7 @@ def _eval_batch(self, params, graphs, model_state, rng): spec.ForwardPassMode.EVAL, rng, update_batch_norm=False) - return self._eval_metric(labels, logits) + return self._eval_metric(labels, logits, masks) def eval_model( self, @@ -255,19 +228,15 @@ def eval_model( """Run a full evaluation of the model.""" data_rng, model_rng = prng.split(rng, 2) eval_batch_size = 256 - if self._eval_ds is None: - self._eval_ds = self._build_dataset( + if self._eval_iterator is None: + self._eval_iterator = self._build_iterator( data_rng, 'validation', data_dir, batch_size=eval_batch_size) - self._model.deterministic = True - params = jax_utils.unreplicate(params) - model_state = jax_utils.unreplicate(model_state) - total_metrics = None # Loop over graph batches in eval dataset - for graphs in self._eval_ds.as_numpy_iterator(): - # graphs = self._shard_graphs(graphs) - batch_metrics = self._eval_batch(params, graphs, model_state, model_rng) + for graphs, labels, masks in self._eval_iterator: + batch_metrics = self._eval_batch( + params, graphs, labels, masks, model_state, model_rng) total_metrics = (batch_metrics if total_metrics is None else total_metrics.merge(batch_metrics)) - return {k: float(v) for k, v in total_metrics.compute().items()} + return {k: float(v) for k, v in total_metrics.reduce().compute().items()} diff --git a/workloads/wmt/wmt_jax/workload.py b/workloads/wmt/wmt_jax/workload.py index 41d3f306d..6f5acb883 100644 --- a/workloads/wmt/wmt_jax/workload.py +++ b/workloads/wmt/wmt_jax/workload.py @@ -418,8 +418,9 @@ def model_fn( def loss_fn( self, label_batch: spec.Tensor, # Dense (not one-hot) labels. - logits_batch: spec.Tensor) -> spec.Tensor: - + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor]) -> spec.Tensor: + del mask_batch weights = jnp.where(label_batch > 0, 1.0, 0.0) loss, _ = self.compute_weighted_cross_entropy(logits_batch, label_batch, weights) From 47ad6050188abf9024a94b034c0c1bbf754b4ae2 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Thu, 2 Dec 2021 22:15:08 -0500 Subject: [PATCH 15/46] temporarily removing some functionality from convert_to_graphs_tuple --- workloads/ogb/ogb_jax/input_pipeline.py | 62 ++++++++++++------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/workloads/ogb/ogb_jax/input_pipeline.py b/workloads/ogb/ogb_jax/input_pipeline.py index aae4cef00..27601afe7 100644 --- a/workloads/ogb/ogb_jax/input_pipeline.py +++ b/workloads/ogb/ogb_jax/input_pipeline.py @@ -153,37 +153,37 @@ def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], senders = graph['edge_index'][:, 0] receivers = graph['edge_index'][:, 1] - # Add a virtual node connected to all other nodes. 
- # The feature vectors for the virtual node - # and the new edges are set to all zeros. - if add_virtual_node: - nodes = tf.concat( - [nodes, tf.zeros_like(nodes[0, None])], axis=0) - senders = tf.concat( - [senders, tf.range(num_nodes)], axis=0) - receivers = tf.concat( - [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) - edges = tf.concat( - [edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) - num_edges += num_nodes - num_nodes += 1 - - # Make edges undirected, by adding edges with senders and receivers flipped. - # The feature vector for the flipped edge is the same as the original edge. - if add_undirected_edges: - new_senders = tf.concat([senders, receivers], axis=0) - new_receivers = tf.concat([receivers, senders], axis=0) - edges = tf.concat([edges, edges], axis=0) - senders, receivers = new_senders, new_receivers - num_edges *= 2 - - # Add self-loops for each node. - # The feature vectors for the self-loops are set to all zeros. - if add_self_loops: - senders = tf.concat([senders, tf.range(num_nodes)], axis=0) - receivers = tf.concat([receivers, tf.range(num_nodes)], axis=0) - edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) - num_edges += num_nodes + # # Add a virtual node connected to all other nodes. + # # The feature vectors for the virtual node + # # and the new edges are set to all zeros. + # if add_virtual_node: + # nodes = tf.concat( + # [nodes, tf.zeros_like(nodes[0, None])], axis=0) + # senders = tf.concat( + # [senders, tf.range(num_nodes)], axis=0) + # receivers = tf.concat( + # [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) + # edges = tf.concat( + # [edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) + # num_edges += num_nodes + # num_nodes += 1 + + # # Make edges undirected, by adding edges with senders and receivers flipped. + # # The feature vector for the flipped edge is the same as the original edge. + # if add_undirected_edges: + # new_senders = tf.concat([senders, receivers], axis=0) + # new_receivers = tf.concat([receivers, senders], axis=0) + # edges = tf.concat([edges, edges], axis=0) + # senders, receivers = new_senders, new_receivers + # num_edges *= 2 + + # # Add self-loops for each node. + # # The feature vectors for the self-loops are set to all zeros. 
+ # if add_self_loops: + # senders = tf.concat([senders, tf.range(num_nodes)], axis=0) + # receivers = tf.concat([receivers, tf.range(num_nodes)], axis=0) + # edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) + # num_edges += num_nodes return jraph.GraphsTuple( n_node=tf.expand_dims(num_nodes, 0), From 86763d36780ab2d5e2b09ef3ae92513030df6372 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Thu, 2 Dec 2021 22:26:55 -0500 Subject: [PATCH 16/46] small hack for total_metrics --- workloads/ogb/ogb_jax/workload.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/workloads/ogb/ogb_jax/workload.py b/workloads/ogb/ogb_jax/workload.py index dbe39c2cd..add36c831 100644 --- a/workloads/ogb/ogb_jax/workload.py +++ b/workloads/ogb/ogb_jax/workload.py @@ -239,4 +239,6 @@ def eval_model( params, graphs, labels, masks, model_state, model_rng) total_metrics = (batch_metrics if total_metrics is None else total_metrics.merge(batch_metrics)) + if total_metrics is None: + return {} return {k: float(v) for k, v in total_metrics.reduce().compute().items()} From 9fe6bdd401d441056792dc7478b2a6e250046f7b Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 2 Feb 2022 19:19:54 -0500 Subject: [PATCH 17/46] fixing jax install on GPUs --- setup.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 27466dc1a..2750e171e 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,6 @@ jax_core_deps = [ - 'jax==0.2.17', 'flax==0.3.5', 'optax==0.0.9', 'tensorflow_datasets==4.4.0', @@ -36,10 +35,10 @@ 'scikit-learn==1.0.1', ], extras_require={ - 'jax-cpu': jax_core_deps + ['jaxlib==0.1.71'], + 'jax-cpu': jax_core_deps + ['jax==0.2.28', 'jaxlib==0.1.76'], # Note for GPU support the installer must be run with # `-f 'https://storage.googleapis.com/jax-releases/jax_releases.html'`. 
- 'jax-gpu': jax_core_deps + ['jaxlib==0.1.71+cuda111'], + 'jax-gpu': jax_core_deps + ['jax[cuda]==0.2.28', 'jaxlib==0.1.76+cuda111.cudnn82'], 'pytorch': [ 'torch==1.9.1+cu111', 'torchvision==0.10.1+cu111', From a431012d69686f252d8c28f23f2d4a68e6784d27 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 7 Feb 2022 17:10:27 -0500 Subject: [PATCH 18/46] moving ogb dir --- algorithmic_efficiency/submission_runner.py | 4 ++-- .../workloads}/ogb/__init__.py | 0 .../workloads}/ogb/ogb_jax/README.md | 0 .../workloads}/ogb/ogb_jax/__init__.py | 0 .../workloads}/ogb/ogb_jax/input_pipeline.py | 0 .../workloads}/ogb/ogb_jax/metrics.py | 0 .../workloads}/ogb/ogb_jax/models.py | 0 .../workloads}/ogb/ogb_jax/submission.py | 0 .../workloads}/ogb/ogb_jax/tuning_search_space.json | 0 .../workloads}/ogb/ogb_jax/workload.py | 0 .../workloads}/ogb/workload.py | 0 11 files changed, 2 insertions(+), 2 deletions(-) rename {workloads => algorithmic_efficiency/workloads}/ogb/__init__.py (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/README.md (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/__init__.py (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/input_pipeline.py (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/metrics.py (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/models.py (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/submission.py (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/tuning_search_space.json (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/ogb_jax/workload.py (100%) rename {workloads => algorithmic_efficiency/workloads}/ogb/workload.py (100%) diff --git a/algorithmic_efficiency/submission_runner.py b/algorithmic_efficiency/submission_runner.py index 874a36030..aa386708d 100644 --- a/algorithmic_efficiency/submission_runner.py +++ b/algorithmic_efficiency/submission_runner.py @@ -49,8 +49,8 @@ 'workload_class_name': 'ImagenetWorkload' }, 'ogb_jax': { - 'workload_path': 'workloads/ogb/ogb_jax/workload.py', - 'workload_class_name': 'OGBWorkload' + 'workload_path': BASE_WORKLOADS_DIR + 'ogb/ogb_jax/workload.py', + 'workload_class_name': 'OGBWorkload' }, 'wmt_jax': { 'workload_path': BASE_WORKLOADS_DIR + 'wmt/wmt_jax/workload.py', diff --git a/workloads/ogb/__init__.py b/algorithmic_efficiency/workloads/ogb/__init__.py similarity index 100% rename from workloads/ogb/__init__.py rename to algorithmic_efficiency/workloads/ogb/__init__.py diff --git a/workloads/ogb/ogb_jax/README.md b/algorithmic_efficiency/workloads/ogb/ogb_jax/README.md similarity index 100% rename from workloads/ogb/ogb_jax/README.md rename to algorithmic_efficiency/workloads/ogb/ogb_jax/README.md diff --git a/workloads/ogb/ogb_jax/__init__.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/__init__.py similarity index 100% rename from workloads/ogb/ogb_jax/__init__.py rename to algorithmic_efficiency/workloads/ogb/ogb_jax/__init__.py diff --git a/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py similarity index 100% rename from workloads/ogb/ogb_jax/input_pipeline.py rename to algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py diff --git a/workloads/ogb/ogb_jax/metrics.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py similarity index 100% rename from workloads/ogb/ogb_jax/metrics.py rename to algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py diff 
--git a/workloads/ogb/ogb_jax/models.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py similarity index 100% rename from workloads/ogb/ogb_jax/models.py rename to algorithmic_efficiency/workloads/ogb/ogb_jax/models.py diff --git a/workloads/ogb/ogb_jax/submission.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/submission.py similarity index 100% rename from workloads/ogb/ogb_jax/submission.py rename to algorithmic_efficiency/workloads/ogb/ogb_jax/submission.py diff --git a/workloads/ogb/ogb_jax/tuning_search_space.json b/algorithmic_efficiency/workloads/ogb/ogb_jax/tuning_search_space.json similarity index 100% rename from workloads/ogb/ogb_jax/tuning_search_space.json rename to algorithmic_efficiency/workloads/ogb/ogb_jax/tuning_search_space.json diff --git a/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py similarity index 100% rename from workloads/ogb/ogb_jax/workload.py rename to algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py diff --git a/workloads/ogb/workload.py b/algorithmic_efficiency/workloads/ogb/workload.py similarity index 100% rename from workloads/ogb/workload.py rename to algorithmic_efficiency/workloads/ogb/workload.py From 1f4164a7814682bb3362229c283d8a6adbec0ecb Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 7 Feb 2022 17:21:38 -0500 Subject: [PATCH 19/46] adding tb profiling --- algorithmic_efficiency/submission_runner.py | 92 +++++++++++---------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/algorithmic_efficiency/submission_runner.py b/algorithmic_efficiency/submission_runner.py index aa386708d..959e7c96c 100644 --- a/algorithmic_efficiency/submission_runner.py +++ b/algorithmic_efficiency/submission_runner.py @@ -12,6 +12,7 @@ """ import importlib import inspect +import jax import json import os import struct @@ -173,51 +174,54 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, training_complete = False global_start_time = time.time() + jax.profiler.start_trace("/tmp/tensorboard") logging.info('Starting training loop.') - while (is_time_remaining and not goal_reached and not training_complete): - step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - start_time = time.time() - selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( - workload, - input_queue, - optimizer_state, - model_params, - hyperparameters, - global_step, - data_select_rng) - try: - optimizer_state, model_params, model_state = update_params( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types(), - model_state=model_state, - hyperparameters=hyperparameters, - input_batch=selected_train_input_batch, - label_batch=selected_train_label_batch, - mask_batch=selected_train_mask_batch, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=update_rng) - except spec.TrainingCompleteError: - training_complete = True - global_step += 1 - current_time = time.time() - accumulated_submission_time += current_time - start_time - is_time_remaining = ( - accumulated_submission_time < workload.max_allowed_runtime_sec) - # Check if submission is eligible for an untimed eval. 
- if (current_time - last_eval_time >= workload.eval_period_time_sec or - training_complete): - latest_eval_result = workload.eval_model(model_params, model_state, - eval_rng, data_dir) - logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' - f'\t{latest_eval_result}') - last_eval_time = current_time - eval_results.append((global_step, latest_eval_result)) - goal_reached = workload.has_reached_goal(latest_eval_result) + for _ in range(10): + with jax.profiler.StepTraceAnnotation("train", step_num=global_step): + step_rng = prng.fold_in(rng, global_step) + data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) + start_time = time.time() + selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( + workload, + input_queue, + optimizer_state, + model_params, + hyperparameters, + global_step, + data_select_rng) + try: + optimizer_state, model_params, model_state = update_params( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types(), + model_state=model_state, + hyperparameters=hyperparameters, + input_batch=selected_train_input_batch, + label_batch=selected_train_label_batch, + mask_batch=selected_train_mask_batch, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=update_rng) + except spec.TrainingCompleteError: + training_complete = True + global_step += 1 + current_time = time.time() + accumulated_submission_time += current_time - start_time + is_time_remaining = ( + accumulated_submission_time < workload.max_allowed_runtime_sec) + # Check if submission is eligible for an untimed eval. + if (current_time - last_eval_time >= workload.eval_period_time_sec or + training_complete): + latest_eval_result = workload.eval_model(model_params, model_state, + eval_rng, data_dir) + logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' + f'\t{latest_eval_result}') + last_eval_time = current_time + eval_results.append((global_step, latest_eval_result)) + goal_reached = workload.has_reached_goal(latest_eval_result) + jax.profiler.stop_trace() metrics = {'eval_results': eval_results, 'global_step': global_step} return accumulated_submission_time, metrics From 5d1ebc18dba4c03c965db0786ec61037a3b04e89 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 7 Feb 2022 17:46:23 -0500 Subject: [PATCH 20/46] adding trace event, flax profiling, moving submission to baselines/ --- algorithmic_efficiency/submission_runner.py | 20 ++- .../workloads/ogb/ogb_jax/submission.py | 116 ------------------ .../ogb/ogb_jax/tuning_search_space.json | 1 - 3 files changed, 7 insertions(+), 130 deletions(-) delete mode 100644 algorithmic_efficiency/workloads/ogb/ogb_jax/submission.py delete mode 100644 algorithmic_efficiency/workloads/ogb/ogb_jax/tuning_search_space.json diff --git a/algorithmic_efficiency/submission_runner.py b/algorithmic_efficiency/submission_runner.py index 959e7c96c..9576b3d38 100644 --- a/algorithmic_efficiency/submission_runner.py +++ b/algorithmic_efficiency/submission_runner.py @@ -15,6 +15,8 @@ import jax import json import os +# Enable flax xprof trace labelling. 
+os.environ['FLAX_PROFILE'] = 'true' import struct import time from typing import Optional, Tuple @@ -208,19 +210,11 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, training_complete = True global_step += 1 current_time = time.time() - accumulated_submission_time += current_time - start_time - is_time_remaining = ( - accumulated_submission_time < workload.max_allowed_runtime_sec) - # Check if submission is eligible for an untimed eval. - if (current_time - last_eval_time >= workload.eval_period_time_sec or - training_complete): - latest_eval_result = workload.eval_model(model_params, model_state, - eval_rng, data_dir) - logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' - f'\t{latest_eval_result}') - last_eval_time = current_time - eval_results.append((global_step, latest_eval_result)) - goal_reached = workload.has_reached_goal(latest_eval_result) + latest_eval_result = workload.eval_model(model_params, model_state, + eval_rng, data_dir) + logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' + f'\t{latest_eval_result}') + eval_results.append((global_step, latest_eval_result)) jax.profiler.stop_trace() metrics = {'eval_results': eval_results, 'global_step': global_step} return accumulated_submission_time, metrics diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/submission.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/submission.py deleted file mode 100644 index 62caabf93..000000000 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/submission.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Iterator, List, Optional, Tuple - -import functools -import numpy as np -import jax -import jax.numpy as jnp -from jax import lax -from flax import jax_utils -import jraph -import optax - -import spec - - -def get_batch_size(workload_name): - batch_sizes = {'ogb_jax': 256} - return batch_sizes[workload_name] - - -def optimizer(hyperparameters: spec.Hyperparamters) -> optax.GradientTransformation: - """Creates an optimizer.""" - opt_init_fn, opt_update_fn = optax.adam( - learning_rate=hyperparameters.learning_rate) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state( - workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparamters, - rng: spec.RandomState) -> spec.OptimizerState: - params_zeros_like = jax.tree_map( - lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters) - init_optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(init_optimizer_state), opt_update_fn - - -# We need to jax.pmap here instead of inside update_params because the latter -# would recompile the function every step. 
-@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, input_batch, - label_batch, mask_batch, rng): - del hyperparameters - - def loss_fn(params): - logits_batch, new_model_state = workload.model_fn( - params, - input_batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss = workload.loss_fn(label_batch, logits_batch, mask_batch) - mean_loss = jnp.sum(jnp.where(mask_batch, loss, 0)) / jnp.sum(mask_batch) - return mean_loss, new_model_state - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (_, new_model_state), grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparamters, - input_batch: spec.Tensor, - label_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor], - loss_type: spec.LossType, - # This will define the output activation via `output_activation_fn`. - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, input_batch, label_batch, - mask_batch, rng) - - #steps_per_epoch = workload.num_train_examples // get_batch_size('ogb_jax') - #if (global_step + 1) % steps_per_epoch == 0: - # # sync batch statistics across replicas once per epoch - # new_model_state = workload.sync_batch_stats(new_model_state) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection( - workload: spec.Workload, - input_queue: Iterator[Tuple[spec.Tensor, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - hyperparameters: spec.Hyperparamters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue.""" - return next(input_queue) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/tuning_search_space.json b/algorithmic_efficiency/workloads/ogb/ogb_jax/tuning_search_space.json deleted file mode 100644 index 7aba31610..000000000 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/tuning_search_space.json +++ /dev/null @@ -1 +0,0 @@ -{"learning_rate": {"feasible_points": [1e-3]}} \ No newline at end of file From ca10de7d8ba5eeb8ea8fcca7e85410c265e0564e Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 7 Feb 2022 18:06:31 -0500 Subject: [PATCH 21/46] adding baselines/ogb and __init__ --- baselines/__init__.py | 0 baselines/ogb/ogb_jax/submission.py | 117 ++++++++++++++++++ 
.../ogb/ogb_jax/tuning_search_space.json | 1 + 3 files changed, 118 insertions(+) create mode 100644 baselines/__init__.py create mode 100644 baselines/ogb/ogb_jax/submission.py create mode 100644 baselines/ogb/ogb_jax/tuning_search_space.json diff --git a/baselines/__init__.py b/baselines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/baselines/ogb/ogb_jax/submission.py b/baselines/ogb/ogb_jax/submission.py new file mode 100644 index 000000000..ce7380613 --- /dev/null +++ b/baselines/ogb/ogb_jax/submission.py @@ -0,0 +1,117 @@ +from typing import Iterator, List, Optional, Tuple + +import functools +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax +from flax import jax_utils +import jraph +import optax + +import spec + + +def get_batch_size(workload_name): + batch_sizes = {'ogb_jax': 256} + return batch_sizes[workload_name] + + +def optimizer(hyperparameters: spec.Hyperparamters) -> optax.GradientTransformation: + """Creates an optimizer.""" + opt_init_fn, opt_update_fn = optax.adam( + learning_rate=hyperparameters.learning_rate) + return opt_init_fn, opt_update_fn + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparamters, + rng: spec.RandomState) -> spec.OptimizerState: + params_zeros_like = jax.tree_map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) + opt_init_fn, opt_update_fn = optimizer(hyperparameters) + init_optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(init_optimizer_state), opt_update_fn + + +# We need to jax.pmap here instead of inside update_params because the latter +# would recompile the function every step. +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0, 0, None), + static_broadcasted_argnums=(0, 1)) +def pmapped_train_step(workload, opt_update_fn, model_state, optimizer_state, + current_param_container, hyperparameters, input_batch, + label_batch, mask_batch, rng): + del hyperparameters + + def loss_fn(params): + logits_batch, new_model_state = workload.model_fn( + params, + input_batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss = workload.loss_fn(label_batch, logits_batch, mask_batch) + mean_loss = jnp.sum(jnp.where(mask_batch, loss, 0)) / jnp.sum(mask_batch) + return mean_loss, new_model_state + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (_, new_model_state), grad = grad_fn(current_param_container) + grad = lax.pmean(grad, axis_name='batch') + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_model_state, new_optimizer_state, updated_params + + +@jax.profiler.annotate_function +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparamters, + input_batch: spec.Tensor, + label_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor], + loss_type: spec.LossType, + # This will define the output activation via `output_activation_fn`. 
+ optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del global_step + + optimizer_state, opt_update_fn = optimizer_state + new_model_state, new_optimizer_state, new_params = pmapped_train_step( + workload, opt_update_fn, model_state, optimizer_state, + current_param_container, hyperparameters, input_batch, label_batch, + mask_batch, rng) + + #steps_per_epoch = workload.num_train_examples // get_batch_size('ogb_jax') + #if (global_step + 1) % steps_per_epoch == 0: + # # sync batch statistics across replicas once per epoch + # new_model_state = workload.sync_batch_stats(new_model_state) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Tuple[spec.Tensor, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparamters, + global_step: int, + rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue.""" + return next(input_queue) diff --git a/baselines/ogb/ogb_jax/tuning_search_space.json b/baselines/ogb/ogb_jax/tuning_search_space.json new file mode 100644 index 000000000..7aba31610 --- /dev/null +++ b/baselines/ogb/ogb_jax/tuning_search_space.json @@ -0,0 +1 @@ +{"learning_rate": {"feasible_points": [1e-3]}} \ No newline at end of file From 0bf0693f79e25b82fceb50c96b77ac7efea09c92 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 9 Feb 2022 16:45:01 -0500 Subject: [PATCH 22/46] fixing file locations and imports after refactor --- .../workloads/ogb/ogb_jax/workload.py | 10 +++++----- algorithmic_efficiency/workloads/ogb/workload.py | 2 +- baselines/ogb/ogb_jax/submission.py | 2 +- .../submission_runner.py => submission_runner.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) rename algorithmic_efficiency/submission_runner.py => submission_runner.py (99%) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index add36c831..a825d5bb7 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -12,11 +12,11 @@ from flax import linen as nn from flax import jax_utils -import spec -from workloads.ogb.workload import OGB -from workloads.ogb.ogb_jax import input_pipeline -from workloads.ogb.ogb_jax import models -from workloads.ogb.ogb_jax import metrics +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogb.workload import OGB +from algorithmic_efficiency.workloads.ogb.ogb_jax import input_pipeline +from algorithmic_efficiency.workloads.ogb.ogb_jax import models +from algorithmic_efficiency.workloads.ogb.ogb_jax import metrics class OGBWorkload(OGB): diff --git a/algorithmic_efficiency/workloads/ogb/workload.py b/algorithmic_efficiency/workloads/ogb/workload.py index 491ebcf86..ca8d36c64 100644 --- a/algorithmic_efficiency/workloads/ogb/workload.py +++ b/algorithmic_efficiency/workloads/ogb/workload.py @@ -1,4 +1,4 @@ -import spec +from algorithmic_efficiency import spec class OGB(spec.Workload): diff --git a/baselines/ogb/ogb_jax/submission.py b/baselines/ogb/ogb_jax/submission.py index ce7380613..a9044994f 
100644 --- a/baselines/ogb/ogb_jax/submission.py +++ b/baselines/ogb/ogb_jax/submission.py @@ -9,7 +9,7 @@ import jraph import optax -import spec +from algorithmic_efficiency import spec def get_batch_size(workload_name): diff --git a/algorithmic_efficiency/submission_runner.py b/submission_runner.py similarity index 99% rename from algorithmic_efficiency/submission_runner.py rename to submission_runner.py index 9576b3d38..2ce5b5772 100644 --- a/algorithmic_efficiency/submission_runner.py +++ b/submission_runner.py @@ -27,7 +27,7 @@ from algorithmic_efficiency import halton from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng +from algorithmic_efficiency import random_utils as prng # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = "algorithmic_efficiency/workloads/" From d43667fa05e928018113d9f9b7447002e99322d4 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 9 Feb 2022 17:12:20 -0500 Subject: [PATCH 23/46] fixing val ds iter --- .../workloads/ogb/ogb_jax/workload.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index a825d5bb7..33443551d 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -3,14 +3,14 @@ from typing import Optional, Tuple import functools import numpy as np -import sklearn.metrics - +from flax import linen as nn +from flax import jax_utils +import itertools import jax import jax.numpy as jnp -import random_utils as prng import jraph -from flax import linen as nn -from flax import jax_utils +import random_utils as prng +import sklearn.metrics from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.ogb.workload import OGB @@ -231,9 +231,13 @@ def eval_model( if self._eval_iterator is None: self._eval_iterator = self._build_iterator( data_rng, 'validation', data_dir, batch_size=eval_batch_size) + # Note that this effectively stores the entire val dataset in memory. + self._eval_iterator = itertools.cycle(self._eval_iterator) total_metrics = None - # Loop over graph batches in eval dataset + # Loop over graph batches in eval dataset. + num_val_examples = 43793 # Both val and test have the same number. 
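For reference, a standalone sketch of the step-count arithmetic and caching behaviour relied on here (only the two constants come from this patch; everything else is illustrative):

import itertools
import math

num_val_examples = 43793   # ogbg-molpcba validation split size, a prime
eval_batch_size = 256

# Floor division drops the final partial batch, so one extra step is needed;
# equivalently, use ceiling division.
assert math.ceil(num_val_examples / eval_batch_size) == (
    num_val_examples // eval_batch_size + 1)

# itertools.cycle saves every item it sees on the first pass, which is why
# cycling the eval iterator effectively keeps the whole split in memory.
cycled = itertools.cycle(iter([1, 2, 3]))
print([next(cycled) for _ in range(5)])  # [1, 2, 3, 1, 2]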
+ num_val_steps = num_val_examples // eval_batch_size for graphs, labels, masks in self._eval_iterator: batch_metrics = self._eval_batch( params, graphs, labels, masks, model_state, model_rng) From e643a5dbc6196bd01701e3f7d21a9b2397c89a9d Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 9 Feb 2022 17:17:45 -0500 Subject: [PATCH 24/46] fixing val iter --- .../workloads/ogb/ogb_jax/workload.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index 33443551d..5e2f3e888 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -9,9 +9,9 @@ import jax import jax.numpy as jnp import jraph -import random_utils as prng import sklearn.metrics +from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.ogb.workload import OGB from algorithmic_efficiency.workloads.ogb.ogb_jax import input_pipeline @@ -235,10 +235,12 @@ def eval_model( self._eval_iterator = itertools.cycle(self._eval_iterator) total_metrics = None + # Both val and test have the same (prime) number of examples. + num_val_examples = 43793 + num_val_steps = num_val_examples // eval_batch_size + 1 # Loop over graph batches in eval dataset. - num_val_examples = 43793 # Both val and test have the same number. - num_val_steps = num_val_examples // eval_batch_size - for graphs, labels, masks in self._eval_iterator: + for _ in range(num_val_steps): + graphs, labels, masks = next(self._eval_iterator) batch_metrics = self._eval_batch( params, graphs, labels, masks, model_state, model_rng) total_metrics = (batch_metrics if total_metrics is None From 0ba6c09880a380ad1fd1f0fc4815501a54379917 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 14 Feb 2022 18:17:10 -0500 Subject: [PATCH 25/46] attempting to simplify the input pipeline --- .../workloads/ogb/ogb_jax/input_pipeline.py | 214 +++++------------- .../workloads/ogb/ogb_jax/workload.py | 9 +- submission_runner.py | 2 +- 3 files changed, 62 insertions(+), 163 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index 27601afe7..1497f3ebe 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -4,7 +4,7 @@ """Exposes the ogbg-molpcba dataset in a convenient format.""" import functools -from typing import Dict, NamedTuple +from typing import Any, Dict, NamedTuple import jax import jraph import numpy as np @@ -15,6 +15,10 @@ import tensorflow_datasets as tfds +AVG_NODES_PER_GRAPH = 26 +AVG_EDGES_PER_GRAPH = 56 + + class GraphsTupleSize(NamedTuple): """Helper class to represent padding and graph sizes.""" n_node: int @@ -22,15 +26,15 @@ class GraphsTupleSize(NamedTuple): n_graph: int -def get_raw_datasets() -> Dict[str, tf.data.Dataset]: +def get_raw_dataset( + split_name: str, + data_dir: str, + file_shuffle_seed: Any) -> Dict[str, tf.data.Dataset]: """Returns datasets as tf.data.Dataset, organized by split.""" - ds_builder = tfds.builder('ogbg_molpcba') + ds_builder = tfds.builder('ogbg_molpcba', data_dir=data_dir) ds_builder.download_and_prepare() - ds_splits = ['train', 'validation', 'test'] - datasets = { - split: ds_builder.as_dataset(split=split) for split in ds_splits - } - return datasets + config = 
tfds.ReadConfig(shuffle_seed=file_shuffle_seed) + return ds_builder.as_dataset(split=split_name, read_config=config) def _get_valid_mask(graphs: jraph.GraphsTuple): @@ -73,16 +77,23 @@ def _batch_for_pmap(iterator): count = 0 -def get_dataset_iters(batch_size: int, - add_virtual_node: bool = True, - add_undirected_edges: bool = True, - add_self_loops: bool = True) -> Dict[str, tf.data.Dataset]: +def get_dataset_iter(split_name: str, + data_rng: jax.random.PRNGKey, + data_dir: str, + batch_size: int, + add_virtual_node: bool = True, + add_undirected_edges: bool = True, + add_self_loops: bool = True) -> Dict[str, tf.data.Dataset]: """Returns datasets of batched GraphsTuples, organized by split.""" if batch_size <= 1: raise ValueError('Batch size must be > 1 to account for padding graphs.') + file_shuffle_seed, dataset_shuffle_seed = jax.random.split(data_rng) + file_shuffle_seed = file_shuffle_seed[0] + dataset_shuffle_seed = dataset_shuffle_seed[0] + # Obtain the original datasets. - datasets = get_raw_datasets() + dataset = get_raw_dataset(split_name, data_dir, file_shuffle_seed) # Construct the GraphsTuple converter function. convert_to_graphs_tuple_fn = functools.partial( @@ -92,51 +103,39 @@ def get_dataset_iters(batch_size: int, add_self_loops=add_virtual_node, ) - # Process each split separately. - for split_name in datasets: - # Convert to GraphsTuple. - datasets[split_name] = datasets[split_name].map( - convert_to_graphs_tuple_fn, - num_parallel_calls=tf.data.AUTOTUNE, - deterministic=True) - - # Compute the padding budget for the requested batch size. - budget = estimate_padding_budget_for_batch_size(datasets['train'], batch_size, - num_estimation_graphs=100) - - # Pad an example graph to see what the output shapes will be. - # We will use this shape information when creating the tf.data.Dataset. - example_graph = next(datasets['train'].as_numpy_iterator()) - example_padded_graph = jraph.pad_with_graphs(example_graph, *budget) - padded_graphs_spec = specs_from_graphs_tuple(example_padded_graph) - - # Process each split separately. - for split_name, dataset_split in datasets.items(): - - # Repeat and shuffle the training split. - if split_name == 'train': - dataset_split = dataset_split.shuffle(100, reshuffle_each_iteration=True) - dataset_split = dataset_split.repeat() - # We cache the validation and test sets, since these are small. - else: - dataset_split = dataset_split.cache() - - # Batch and pad each split. Note that this also converts the graphs to - # numpy. - batched_iter = jraph.dynamically_batch( - graphs_tuple_iterator=iter(dataset_split), - n_node=budget.n_node, - n_edge=budget.n_edge, - n_graph=budget.n_graph) + dataset = dataset.map( + convert_to_graphs_tuple_fn, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=True) + + # Repeat and shuffle the training split. + if split_name == 'train': + dataset = dataset.shuffle( + buffer_size=2**15, + seed=dataset_shuffle_seed, + reshuffle_each_iteration=True) + dataset = dataset.repeat() + # We cache the validation and test sets, since these are small. + else: + dataset = dataset.cache() + + # Batch and pad each split. Note that this also converts the graphs to + # numpy. + max_n_nodes = AVG_NODES_PER_GRAPH * batch_size + max_n_edges = AVG_EDGES_PER_GRAPH * batch_size + batched_iter = jraph.dynamically_batch( + graphs_tuple_iterator=iter(dataset), + n_node=max_n_nodes, + n_edge=max_n_edges, + n_graph=batch_size) - # An iterator of Tuple[graph, labels, mask]. 
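For intuition about the fixed padding budget, a toy sketch of jraph.dynamically_batch (the generator and sizes below are made up for illustration): graphs are packed until the node/edge/graph budget would be exceeded, then the batch is padded with a dummy graph up to exactly that budget, so every batch has the same static shapes, which is what jit/pmap need to avoid recompiling.

import jraph
import numpy as np

def toy_graphs():
  # An endless stream of identical small graphs: 3 nodes, 4 directed edges.
  while True:
    yield jraph.GraphsTuple(
        n_node=np.array([3]), n_edge=np.array([4]),
        nodes=np.ones((3, 9)), edges=np.ones((4, 3)),
        senders=np.array([0, 1, 1, 2]), receivers=np.array([1, 0, 2, 1]),
        globals=np.zeros((1, 128)))

batch_size = 4
batched = jraph.dynamically_batch(
    graphs_tuple_iterator=toy_graphs(),
    n_node=26 * batch_size,   # AVG_NODES_PER_GRAPH * batch_size
    n_edge=56 * batch_size,   # AVG_EDGES_PER_GRAPH * batch_size
    n_graph=batch_size)
batch = next(batched)
print(batch.nodes.shape, batch.edges.shape, batch.n_node.shape)
# (104, 9) (224, 3) (4,): padded to the budget regardless of content.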
- masked_iter = map(_get_valid_mask, batched_iter) + # An iterator of Tuple[graph, labels, mask]. + masked_iter = map(_get_valid_mask, batched_iter) - # An iterator the same as above, but where each element has an extra leading - # dim of size jax.local_device_count(). - pmapped_iterator = _batch_for_pmap(masked_iter) - datasets[split_name] = pmapped_iterator - return datasets + # An iterator the same as above, but where each element has an extra leading + # dim of size jax.local_device_count(). + pmapped_iterator = _batch_for_pmap(masked_iter) + return pmapped_iterator def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], @@ -148,43 +147,10 @@ def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], num_edges = tf.squeeze(graph['num_edges']) nodes = graph['node_feat'] edges = graph['edge_feat'] - edge_feature_dim = edges.shape[-1] labels = graph['labels'] senders = graph['edge_index'][:, 0] receivers = graph['edge_index'][:, 1] - # # Add a virtual node connected to all other nodes. - # # The feature vectors for the virtual node - # # and the new edges are set to all zeros. - # if add_virtual_node: - # nodes = tf.concat( - # [nodes, tf.zeros_like(nodes[0, None])], axis=0) - # senders = tf.concat( - # [senders, tf.range(num_nodes)], axis=0) - # receivers = tf.concat( - # [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) - # edges = tf.concat( - # [edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) - # num_edges += num_nodes - # num_nodes += 1 - - # # Make edges undirected, by adding edges with senders and receivers flipped. - # # The feature vector for the flipped edge is the same as the original edge. - # if add_undirected_edges: - # new_senders = tf.concat([senders, receivers], axis=0) - # new_receivers = tf.concat([receivers, senders], axis=0) - # edges = tf.concat([edges, edges], axis=0) - # senders, receivers = new_senders, new_receivers - # num_edges *= 2 - - # # Add self-loops for each node. - # # The feature vectors for the self-loops are set to all zeros. - # if add_self_loops: - # senders = tf.concat([senders, tf.range(num_nodes)], axis=0) - # receivers = tf.concat([receivers, tf.range(num_nodes)], axis=0) - # edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) - # num_edges += num_nodes - return jraph.GraphsTuple( n_node=tf.expand_dims(num_nodes, 0), n_edge=tf.expand_dims(num_edges, 0), @@ -194,73 +160,3 @@ def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], receivers=receivers, globals=tf.expand_dims(labels, axis=0), ) - - -def estimate_padding_budget_for_batch_size( - dataset: tf.data.Dataset, - batch_size: int, - num_estimation_graphs: int) -> GraphsTupleSize: - """Estimates the padding budget for a dataset of unbatched GraphsTuples. - - Args: - dataset: A dataset of unbatched GraphsTuples. - batch_size: The intended batch size. Note that no batching is performed by - this function. - num_estimation_graphs: How many graphs to take from the dataset to estimate - the distribution of number of nodes and edges per graph. - - Returns: - padding_budget: The padding budget for batching and padding the graphs - in this dataset to the given batch size. 
- """ - - def next_multiple_of_64(val: float): - """Returns the next multiple of 64 after val.""" - return 64 * (1 + int(val // 64)) - - if batch_size <= 1: - raise ValueError('Batch size must be > 1 to account for padding graphs.') - - total_num_nodes = 0 - total_num_edges = 0 - for graph in dataset.take(num_estimation_graphs).as_numpy_iterator(): - graph_size = get_graphs_tuple_size(graph) - if graph_size.n_graph != 1: - raise ValueError('Dataset contains batched GraphTuples.') - - total_num_nodes += graph_size.n_node - total_num_edges += graph_size.n_edge - - num_nodes_per_graph_estimate = total_num_nodes / num_estimation_graphs - num_edges_per_graph_estimate = total_num_edges / num_estimation_graphs - - padding_budget = GraphsTupleSize( - n_node=next_multiple_of_64(num_nodes_per_graph_estimate * batch_size), - n_edge=next_multiple_of_64(num_edges_per_graph_estimate * batch_size), - n_graph=batch_size) - return padding_budget - - -def specs_from_graphs_tuple(graph: jraph.GraphsTuple): - """Returns a tf.TensorSpec corresponding to this graph.""" - - def get_tensor_spec(array: np.ndarray): - shape = list(array.shape) - dtype = array.dtype - return tf.TensorSpec(shape=shape, dtype=dtype) - - specs = {} - for field in [ - 'nodes', 'edges', 'senders', 'receivers', 'globals', 'n_node', 'n_edge' - ]: - field_sample = getattr(graph, field) - specs[field] = get_tensor_spec(field_sample) - return jraph.GraphsTuple(**specs) - - -def get_graphs_tuple_size(graph: jraph.GraphsTuple): - """Returns the number of nodes, edges and graphs in a GraphsTuple.""" - return GraphsTupleSize( - n_node=np.sum(graph.n_node), - n_edge=np.sum(graph.n_edge), - n_graph=np.shape(graph.n_node)[0]) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index 5e2f3e888..176b78d8f 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -40,16 +40,19 @@ def _build_iterator( split: str, data_dir: str, batch_size: int): - dataset_iters = input_pipeline.get_dataset_iters( + dataset_iter = input_pipeline.get_dataset_iter( + split, + data_rng, + data_dir, batch_size, add_virtual_node=False, add_undirected_edges=True, add_self_loops=True) if self._init_graphs is None: - init_graphs = next(dataset_iters['train'])[0] + init_graphs = next(dataset_iter)[0] # Unreplicate the iterator that has the leading dim for pmapping. 
self._init_graphs = jax.tree_map(lambda x: x[0], init_graphs) - return dataset_iters[split] + return dataset_iter def build_input_queue( self, diff --git a/submission_runner.py b/submission_runner.py index 2ce5b5772..d9392c18c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -87,7 +87,7 @@ 'The path to the JSON file describing the external tuning search space.') flags.DEFINE_integer('num_tuning_trials', 20, 'The number of external hyperparameter trials to run.') -flags.DEFINE_string('data_dir', '~/', 'Dataset location') +flags.DEFINE_string('data_dir', '~/tensorflow_datasets/', 'Dataset location') flags.DEFINE_enum( 'framework', None, From 402309d2ec1d6779aa784b0519df627923ee46c8 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 14 Feb 2022 20:09:21 -0500 Subject: [PATCH 26/46] simplifying input pipeline, adding more logging --- .../workloads/ogb/ogb_jax/input_pipeline.py | 14 +++++--------- submission_runner.py | 17 ++++++++++------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index 1497f3ebe..919f8d7b8 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -62,6 +62,7 @@ def _batch_for_pmap(iterator): masks = [] count = 0 for graph_batch, label_batch, mask_batch in iterator: + graph_batch = _get_valid_mask(graph_batch) count += 1 graphs.append(graph_batch) labels.append(label_batch) @@ -115,12 +116,10 @@ def get_dataset_iter(split_name: str, seed=dataset_shuffle_seed, reshuffle_each_iteration=True) dataset = dataset.repeat() - # We cache the validation and test sets, since these are small. - else: - dataset = dataset.cache() + # We do not need to cache the validation and test sets because we do this + # later with itertools.cycle. - # Batch and pad each split. Note that this also converts the graphs to - # numpy. + # Batch and pad each split. Note that this also converts the graphs to numpy. max_n_nodes = AVG_NODES_PER_GRAPH * batch_size max_n_edges = AVG_EDGES_PER_GRAPH * batch_size batched_iter = jraph.dynamically_batch( @@ -129,12 +128,9 @@ def get_dataset_iter(split_name: str, n_edge=max_n_edges, n_graph=batch_size) - # An iterator of Tuple[graph, labels, mask]. - masked_iter = map(_get_valid_mask, batched_iter) - # An iterator the same as above, but where each element has an extra leading # dim of size jax.local_device_count(). 
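A self-contained sketch of the stacking pattern used for pmap here (the helper name and toy batch are illustrative, not from the patch): group jax.local_device_count() consecutive per-device batches and stack each leaf, so pmap can map over the new leading axis.

import jax
import numpy as np

def batch_for_pmap(batches):
  """Yields pytrees whose leaves gain a leading axis of size local_device_count()."""
  per_host = jax.local_device_count()
  buf = []
  for batch in batches:
    buf.append(batch)
    if len(buf) == per_host:
      # Stack corresponding leaves of the buffered batches along a new axis 0.
      yield jax.tree_map(lambda *leaves: np.stack(leaves, axis=0), *buf)
      buf = []

device_batches = ({'nodes': np.ones((8, 9))} for _ in range(2 * jax.local_device_count()))
print(next(batch_for_pmap(device_batches))['nodes'].shape)
# (local_device_count, 8, 9)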
- pmapped_iterator = _batch_for_pmap(masked_iter) + pmapped_iterator = _batch_for_pmap(batched_iter) return pmapped_iterator diff --git a/submission_runner.py b/submission_runner.py index d9392c18c..3df88dc73 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -182,7 +182,8 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, with jax.profiler.StepTraceAnnotation("train", step_num=global_step): step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - start_time = time.time() + # start_time = time.time() + logging.info(f'starting step {global_step}') selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( workload, input_queue, @@ -191,6 +192,7 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, hyperparameters, global_step, data_select_rng) + logging.info(f'starting update {global_step}') try: optimizer_state, model_params, model_state = update_params( workload=workload, @@ -208,13 +210,14 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, rng=update_rng) except spec.TrainingCompleteError: training_complete = True + logging.info(f'finished step {global_step}') global_step += 1 - current_time = time.time() - latest_eval_result = workload.eval_model(model_params, model_state, - eval_rng, data_dir) - logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' - f'\t{latest_eval_result}') - eval_results.append((global_step, latest_eval_result)) + # current_time = time.time() + # latest_eval_result = workload.eval_model(model_params, model_state, + # eval_rng, data_dir) + # logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' + # f'\t{latest_eval_result}') + # eval_results.append((global_step, latest_eval_result)) jax.profiler.stop_trace() metrics = {'eval_results': eval_results, 'global_step': global_step} return accumulated_submission_time, metrics From 2a50150bf4553e6f14bd8486ed7bd61139b372ff Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 14 Feb 2022 20:12:05 -0500 Subject: [PATCH 27/46] fix --- .../workloads/ogb/ogb_jax/input_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index 919f8d7b8..daa3c0882 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -61,8 +61,8 @@ def _batch_for_pmap(iterator): labels = [] masks = [] count = 0 - for graph_batch, label_batch, mask_batch in iterator: - graph_batch = _get_valid_mask(graph_batch) + for graphs in iterator: + graph_batch, label_batch, mask_batch = _get_valid_mask(graphs) count += 1 graphs.append(graph_batch) labels.append(label_batch) From 7ca778fae45491cbe118c56f62c7346f22deadad Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 14 Feb 2022 20:15:16 -0500 Subject: [PATCH 28/46] fix pls --- .../workloads/ogb/ogb_jax/input_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index daa3c0882..ca51ee7b7 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -61,8 +61,8 @@ def _batch_for_pmap(iterator): labels = [] masks = [] count = 0 - for 
graphs in iterator: - graph_batch, label_batch, mask_batch = _get_valid_mask(graphs) + for batch in iterator: + graph_batch, label_batch, mask_batch = _get_valid_mask(batch) count += 1 graphs.append(graph_batch) labels.append(label_batch) From 321364b2446571a2e82bfb0aaaf985c199611163 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 14 Feb 2022 20:39:33 -0500 Subject: [PATCH 29/46] proper num eval steps --- .../workloads/ogb/ogb_jax/input_pipeline.py | 52 +++++++++---------- .../workloads/ogb/ogb_jax/workload.py | 10 ++-- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index ca51ee7b7..67153badb 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -26,7 +26,7 @@ class GraphsTupleSize(NamedTuple): n_graph: int -def get_raw_dataset( +def _get_raw_dataset( split_name: str, data_dir: str, file_shuffle_seed: Any) -> Dict[str, tf.data.Dataset]: @@ -37,6 +37,30 @@ def get_raw_dataset( return ds_builder.as_dataset(split=split_name, read_config=config) +def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], + add_virtual_node: bool, + add_undirected_edges: bool, + add_self_loops: bool) -> jraph.GraphsTuple: + """Converts a dictionary of tf.Tensors to a GraphsTuple.""" + num_nodes = tf.squeeze(graph['num_nodes']) + num_edges = tf.squeeze(graph['num_edges']) + nodes = graph['node_feat'] + edges = graph['edge_feat'] + labels = graph['labels'] + senders = graph['edge_index'][:, 0] + receivers = graph['edge_index'][:, 1] + + return jraph.GraphsTuple( + n_node=tf.expand_dims(num_nodes, 0), + n_edge=tf.expand_dims(num_edges, 0), + nodes=nodes, + edges=edges, + senders=senders, + receivers=receivers, + globals=tf.expand_dims(labels, axis=0), + ) + + def _get_valid_mask(graphs: jraph.GraphsTuple): """Gets the binary mask indicating only valid labels and graphs.""" labels = graphs.globals @@ -94,7 +118,7 @@ def get_dataset_iter(split_name: str, dataset_shuffle_seed = dataset_shuffle_seed[0] # Obtain the original datasets. - dataset = get_raw_dataset(split_name, data_dir, file_shuffle_seed) + dataset = _get_raw_dataset(split_name, data_dir, file_shuffle_seed) # Construct the GraphsTuple converter function. convert_to_graphs_tuple_fn = functools.partial( @@ -132,27 +156,3 @@ def get_dataset_iter(split_name: str, # dim of size jax.local_device_count(). 
pmapped_iterator = _batch_for_pmap(batched_iter) return pmapped_iterator - - -def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], - add_virtual_node: bool, - add_undirected_edges: bool, - add_self_loops: bool) -> jraph.GraphsTuple: - """Converts a dictionary of tf.Tensors to a GraphsTuple.""" - num_nodes = tf.squeeze(graph['num_nodes']) - num_edges = tf.squeeze(graph['num_edges']) - nodes = graph['node_feat'] - edges = graph['edge_feat'] - labels = graph['labels'] - senders = graph['edge_index'][:, 0] - receivers = graph['edge_index'][:, 1] - - return jraph.GraphsTuple( - n_node=tf.expand_dims(num_nodes, 0), - n_edge=tf.expand_dims(num_edges, 0), - nodes=nodes, - edges=edges, - senders=senders, - receivers=receivers, - globals=tf.expand_dims(labels, axis=0), - ) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index 176b78d8f..81fa06a14 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -1,5 +1,5 @@ """OGB workload implemented in Jax.""" - +from absl import logging from typing import Optional, Tuple import functools import numpy as np @@ -230,7 +230,7 @@ def eval_model( data_dir: str): """Run a full evaluation of the model.""" data_rng, model_rng = prng.split(rng, 2) - eval_batch_size = 256 + eval_batch_size = 1024 if self._eval_iterator is None: self._eval_iterator = self._build_iterator( data_rng, 'validation', data_dir, batch_size=eval_batch_size) @@ -240,9 +240,11 @@ def eval_model( total_metrics = None # Both val and test have the same (prime) number of examples. num_val_examples = 43793 - num_val_steps = num_val_examples // eval_batch_size + 1 + num_val_steps = ( + num_val_examples // (eval_batch_size * jax.local_device_count()) + 1) # Loop over graph batches in eval dataset. 
- for _ in range(num_val_steps): + for s in range(num_val_steps): + logging.info(f'eval step {s}') graphs, labels, masks = next(self._eval_iterator) batch_metrics = self._eval_batch( params, graphs, labels, masks, model_state, model_rng) From f147a890d1ea1bf6a54e5a3946522333f213b456 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 16 Feb 2022 15:57:13 -0500 Subject: [PATCH 30/46] undoing submission runner changes to try to repro on 1xv100 --- .../workloads/ogb/ogb_jax/input_pipeline.py | 3 +- .../workloads/ogb/ogb_jax/workload.py | 32 ++----- setup.py | 53 +---------- submission_runner.py | 87 ++++++++++--------- 4 files changed, 54 insertions(+), 121 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index 67153badb..ee6ee5fb1 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -57,8 +57,7 @@ def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], edges=edges, senders=senders, receivers=receivers, - globals=tf.expand_dims(labels, axis=0), - ) + globals=tf.expand_dims(labels, axis=0)) def _get_valid_mask(graphs: jraph.GraphsTuple): diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index 81fa06a14..e2cd14ff8 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -179,29 +179,6 @@ def loss_fn( mean_loss = jnp.sum(jnp.where(mask_batch, loss, 0)) / jnp.sum(mask_batch) return mean_loss - def _compute_mean_average_precision(self, labels, logits, masks): - """Computes the mean average precision (mAP) over different tasks.""" - # Matches the official OGB evaluation scheme for mean average precision. - assert logits.shape == labels.shape == masks.shape - assert len(logits.shape) == 2 - - probs = jax.nn.sigmoid(logits) - num_tasks = labels.shape[1] - average_precisions = np.full(num_tasks, np.nan) - - for task in range(num_tasks): - # AP is only defined when there is at least one negative data - # and at least one positive data. - if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: - is_labeled = masks[:, task] - average_precisions[task] = sklearn.metrics.average_precision_score( - labels[is_labeled, task], probs[is_labeled, task]) - - # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. - if np.isnan(average_precisions).all(): - return np.nan - return np.nanmean(average_precisions) - def _eval_metric(self, labels, logits, masks): loss = self.loss_fn(labels, logits, masks) return metrics.EvalMetrics.single_from_model_output( @@ -230,18 +207,19 @@ def eval_model( data_dir: str): """Run a full evaluation of the model.""" data_rng, model_rng = prng.split(rng, 2) - eval_batch_size = 1024 + eval_per_core_batch_size = 1024 if self._eval_iterator is None: self._eval_iterator = self._build_iterator( - data_rng, 'validation', data_dir, batch_size=eval_batch_size) + data_rng, 'validation', data_dir, batch_size=eval_per_core_batch_size) # Note that this effectively stores the entire val dataset in memory. self._eval_iterator = itertools.cycle(self._eval_iterator) total_metrics = None # Both val and test have the same (prime) number of examples. 
num_val_examples = 43793 - num_val_steps = ( - num_val_examples // (eval_batch_size * jax.local_device_count()) + 1) + total_eval_batch_size = eval_per_core_batch_size * jax.local_device_count() + # num_val_steps = num_val_examples // total_eval_batch_size + 1 DO NOT SUBMIT + num_val_steps = 1 # Loop over graph batches in eval dataset. for s in range(num_val_steps): logging.info(f'eval step {s}') diff --git a/setup.py b/setup.py index 4fca242ad..26e08e48e 100644 --- a/setup.py +++ b/setup.py @@ -1,56 +1,5 @@ -"""Setup file for algorithmic_efficiency, use setup.cfg for configuration.""" - from setuptools import setup -<<<<<<< HEAD - -jax_core_deps = [ - 'flax==0.3.5', - 'optax==0.0.9', - 'tensorflow_datasets==4.4.0', - 'tensorflow==2.5.0', -] - -# Assumes CUDA 11.x and a compatible NVIDIA driver and CuDNN. -setup( - name='algorithmic_efficiency', - version='0.0.1', - description='MLCommons Algorithmic Efficiency', - author='MLCommons Algorithmic Efficiency Working Group', - author_email='algorithms@mlcommons.org', - url='https://github.com/mlcommons/algorithmic-efficiency', - license='Apache 2.0', - python_requires=">=3.7", - packages=find_packages(), - install_requires=[ - 'absl-py==0.14.0', - 'clu==0.0.6', - 'jraph==0.0.2.dev', - 'numpy>=1.19.2', - 'pandas==1.3.4', - 'scikit-learn==1.0.1', - ], - extras_require={ - 'jax-cpu': jax_core_deps + ['jax==0.2.28', 'jaxlib==0.1.76'], - # Note for GPU support the installer must be run with - # `-f 'https://storage.googleapis.com/jax-releases/jax_releases.html'`. - 'jax-gpu': jax_core_deps + ['jax[cuda]==0.2.28', 'jaxlib==0.1.76+cuda111.cudnn82'], - 'pytorch': [ - 'torch==1.9.1+cu111', - 'torchvision==0.10.1+cu111', - ], - }, - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - ], - keywords='mlcommons algorithmic efficiency', -) -======= if __name__ == "__main__": - setup() ->>>>>>> main + setup() diff --git a/submission_runner.py b/submission_runner.py index 3df88dc73..60d55c4e2 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -178,46 +178,53 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, jax.profiler.start_trace("/tmp/tensorboard") logging.info('Starting training loop.') - for _ in range(10): - with jax.profiler.StepTraceAnnotation("train", step_num=global_step): - step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - # start_time = time.time() - logging.info(f'starting step {global_step}') - selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( - workload, - input_queue, - optimizer_state, - model_params, - hyperparameters, - global_step, - data_select_rng) - logging.info(f'starting update {global_step}') - try: - optimizer_state, model_params, model_state = update_params( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types(), - model_state=model_state, - hyperparameters=hyperparameters, - input_batch=selected_train_input_batch, - label_batch=selected_train_label_batch, - mask_batch=selected_train_mask_batch, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=update_rng) - except spec.TrainingCompleteError: - training_complete = True - logging.info(f'finished step 
{global_step}') - global_step += 1 - # current_time = time.time() - # latest_eval_result = workload.eval_model(model_params, model_state, - # eval_rng, data_dir) - # logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' - # f'\t{latest_eval_result}') - # eval_results.append((global_step, latest_eval_result)) + while (is_time_remaining and not goal_reached and not training_complete): + step_rng = prng.fold_in(rng, global_step) + data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) + start_time = time.time() + logging.info(f'starting step {global_step}') + selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( + workload, + input_queue, + optimizer_state, + model_params, + hyperparameters, + global_step, + data_select_rng) + logging.info(f'starting update {global_step}') + try: + optimizer_state, model_params, model_state = update_params( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types(), + model_state=model_state, + hyperparameters=hyperparameters, + input_batch=selected_train_input_batch, + label_batch=selected_train_label_batch, + mask_batch=selected_train_mask_batch, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=update_rng) + except spec.TrainingCompleteError: + training_complete = True + logging.info(f'finished step {global_step}') + global_step += 1 + current_time = time.time() + accumulated_submission_time += current_time - start_time + is_time_remaining = ( + accumulated_submission_time < workload.max_allowed_runtime_sec) + # Check if submission is eligible for an untimed eval. + if (current_time - last_eval_time >= workload.eval_period_time_sec or + training_complete): + latest_eval_result = workload.eval_model(model_params, model_state, + eval_rng, data_dir) + logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' + f'\t{latest_eval_result}') + last_eval_time = current_time + eval_results.append((global_step, latest_eval_result)) + goal_reached = workload.has_reached_goal(latest_eval_result) jax.profiler.stop_trace() metrics = {'eval_results': eval_results, 'global_step': global_step} return accumulated_submission_time, metrics From 0c57ca159eff15f9e16028aaac5df90bb65398c6 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 16 Feb 2022 16:04:20 -0500 Subject: [PATCH 31/46] fixing gnn workload limits and setup --- README.md | 4 ++-- algorithmic_efficiency/workloads/ogb/workload.py | 5 +++-- setup.cfg | 7 ++++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 3b466b838..3a6b4d549 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ **JAX (GPU)** ```bash - pip3 install -e .[jax-gpu] -f 'https://storage.googleapis.com/jax-releases/jax_releases.html' + pip3 install -e .[jax_gpu] -f 'https://storage.googleapis.com/jax-releases/jax_releases.html' ``` **JAX (CPU)** ```bash - pip3 install -e .[jax-cpu] + pip3 install -e .[jax_cpu] ``` **PyTorch** diff --git a/algorithmic_efficiency/workloads/ogb/workload.py b/algorithmic_efficiency/workloads/ogb/workload.py index ca8d36c64..ee2a7c06e 100644 --- a/algorithmic_efficiency/workloads/ogb/workload.py +++ b/algorithmic_efficiency/workloads/ogb/workload.py @@ -8,7 +8,8 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - return 0.25 + # From Flax example https://tensorboard.dev/experiment/AAJqfvgSRJaA1MBkc0jMWQ/#scalars. 
+ return 0.24 @property def loss_type(self): @@ -36,4 +37,4 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 360 # 60 minutes (too long) + return 120 diff --git a/setup.cfg b/setup.cfg index 358e6ccd8..6eb992df9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,7 @@ python_requires = >=3.7 [options.extras_require] # Add extra dependencies, e.g. to run tests or for the different frameworks. -# Use as `pip install -e '.[jax-gpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html` +# Use as `pip install -e '.[jax_gpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html` # or `pip install -e '.[dev]'` # Dependencies for developing the package @@ -54,6 +54,11 @@ dev = pytest yapf +# Graph NN workload. +gnn = + jraph==0.0.2.dev + scikit-learn==1.0.1 + # JAX Core jax_core_deps = flax==0.3.5 From 93459396c819e45d45910f78b00fcf86cff3d27b Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 16 Feb 2022 16:07:37 -0500 Subject: [PATCH 32/46] adding clu dep --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 6eb992df9..ec9ecdfda 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,6 +58,7 @@ dev = gnn = jraph==0.0.2.dev scikit-learn==1.0.1 + clu==0.0.6 # JAX Core jax_core_deps = From 65ffb99d064d35c450036ccc9b6bab96dfac5535 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 16 Feb 2022 16:27:01 -0500 Subject: [PATCH 33/46] no logging --- submission_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 60d55c4e2..8904389da 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -176,13 +176,13 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, training_complete = False global_start_time = time.time() - jax.profiler.start_trace("/tmp/tensorboard") + # jax.profiler.start_trace("/tmp/tensorboard") logging.info('Starting training loop.') while (is_time_remaining and not goal_reached and not training_complete): step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) start_time = time.time() - logging.info(f'starting step {global_step}') + # logging.info(f'starting step {global_step}') selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( workload, input_queue, @@ -191,7 +191,7 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, hyperparameters, global_step, data_select_rng) - logging.info(f'starting update {global_step}') + # logging.info(f'starting update {global_step}') try: optimizer_state, model_params, model_state = update_params( workload=workload, @@ -225,7 +225,7 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, last_eval_time = current_time eval_results.append((global_step, latest_eval_result)) goal_reached = workload.has_reached_goal(latest_eval_result) - jax.profiler.stop_trace() + # jax.profiler.stop_trace() metrics = {'eval_results': eval_results, 'global_step': global_step} return accumulated_submission_time, metrics From e687885a729112984b4c2d138ebc6c78c9fa46c0 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 16 Feb 2022 16:30:48 -0500 Subject: [PATCH 34/46] silence --- algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py | 5 ++--- submission_runner.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py 
index e2cd14ff8..85c7c73b2 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -218,11 +218,10 @@ def eval_model( # Both val and test have the same (prime) number of examples. num_val_examples = 43793 total_eval_batch_size = eval_per_core_batch_size * jax.local_device_count() - # num_val_steps = num_val_examples // total_eval_batch_size + 1 DO NOT SUBMIT - num_val_steps = 1 + num_val_steps = num_val_examples // total_eval_batch_size + 1 # Loop over graph batches in eval dataset. for s in range(num_val_steps): - logging.info(f'eval step {s}') + # logging.info(f'eval step {s}') graphs, labels, masks = next(self._eval_iterator) batch_metrics = self._eval_batch( params, graphs, labels, masks, model_state, model_rng) diff --git a/submission_runner.py b/submission_runner.py index 8904389da..af44667b4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -209,7 +209,7 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, rng=update_rng) except spec.TrainingCompleteError: training_complete = True - logging.info(f'finished step {global_step}') + # logging.info(f'finished step {global_step}') global_step += 1 current_time = time.time() accumulated_submission_time += current_time - start_time From 1d31fb5dde4759db4357218d43198c0f4fb125d5 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Fri, 18 Feb 2022 12:51:29 -0500 Subject: [PATCH 35/46] adding in data aug --- .../workloads/ogb/ogb_jax/input_pipeline.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index ee6ee5fb1..6d163ae16 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -46,10 +46,37 @@ def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], num_edges = tf.squeeze(graph['num_edges']) nodes = graph['node_feat'] edges = graph['edge_feat'] + edge_feature_dim = edges.shape[-1] labels = graph['labels'] senders = graph['edge_index'][:, 0] receivers = graph['edge_index'][:, 1] + nodes = tf.concat( + [nodes, tf.zeros_like(nodes[0, None])], axis=0) + senders = tf.concat( + [senders, tf.range(num_nodes)], axis=0) + receivers = tf.concat( + [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) + edges = tf.concat( + [edges, tf.zeros(tf.stack([num_nodes, edge_feature_dim]))], axis=0) + num_edges += num_nodes + num_nodes += 1 + + # Make edges undirected, by adding edges with senders and receivers flipped. + # The feature vector for the flipped edge is the same as the original edge. + new_senders = tf.concat([senders, receivers], axis=0) + new_receivers = tf.concat([receivers, senders], axis=0) + edges = tf.concat([edges, edges], axis=0) + senders, receivers = new_senders, new_receivers + num_edges *= 2 + + # Add self-loops for each node. + # The feature vectors for the self-loops are set to all zeros. 
+ senders = tf.concat([senders, tf.range(num_nodes)], axis=0) + receivers = tf.concat([receivers, tf.range(num_nodes)], axis=0) + edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) + num_edges += num_nodes + return jraph.GraphsTuple( n_node=tf.expand_dims(num_nodes, 0), n_edge=tf.expand_dims(num_edges, 0), From 3b506bf1f27b89089169120ba854ce14c566e4c0 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Mon, 28 Feb 2022 22:54:11 -0500 Subject: [PATCH 36/46] i2w version --- .../workloads/ogb/ogb_jax/input_pipeline.py | 291 +++++++++--------- .../workloads/ogb/ogb_jax/metrics.py | 10 +- .../workloads/ogb/ogb_jax/models.py | 211 ++++--------- .../workloads/ogb/ogb_jax/workload.py | 22 +- 4 files changed, 207 insertions(+), 327 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index 6d163ae16..25d81dbc1 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -1,5 +1,7 @@ # Forked from Flax example which can be found here: # https://github.com/google/flax/blob/main/examples/ogbg_molpcba/input_pipeline.py +# and from the init2winit fork here +# https://github.com/google/init2winit/blob/master/init2winit/dataset_lib/ogbg_molpcba.py """Exposes the ogbg-molpcba dataset in a convenient format.""" @@ -9,9 +11,6 @@ import jraph import numpy as np import tensorflow as tf -# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.experimental.set_visible_devices([], 'GPU') import tensorflow_datasets as tfds @@ -26,159 +25,147 @@ class GraphsTupleSize(NamedTuple): n_graph: int -def _get_raw_dataset( - split_name: str, - data_dir: str, - file_shuffle_seed: Any) -> Dict[str, tf.data.Dataset]: - """Returns datasets as tf.data.Dataset, organized by split.""" - ds_builder = tfds.builder('ogbg_molpcba', data_dir=data_dir) - ds_builder.download_and_prepare() - config = tfds.ReadConfig(shuffle_seed=file_shuffle_seed) - return ds_builder.as_dataset(split=split_name, read_config=config) - - -def convert_to_graphs_tuple(graph: Dict[str, tf.Tensor], - add_virtual_node: bool, - add_undirected_edges: bool, - add_self_loops: bool) -> jraph.GraphsTuple: - """Converts a dictionary of tf.Tensors to a GraphsTuple.""" - num_nodes = tf.squeeze(graph['num_nodes']) - num_edges = tf.squeeze(graph['num_edges']) - nodes = graph['node_feat'] - edges = graph['edge_feat'] - edge_feature_dim = edges.shape[-1] - labels = graph['labels'] - senders = graph['edge_index'][:, 0] - receivers = graph['edge_index'][:, 1] - - nodes = tf.concat( - [nodes, tf.zeros_like(nodes[0, None])], axis=0) - senders = tf.concat( - [senders, tf.range(num_nodes)], axis=0) - receivers = tf.concat( - [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0) - edges = tf.concat( - [edges, tf.zeros(tf.stack([num_nodes, edge_feature_dim]))], axis=0) - num_edges += num_nodes - num_nodes += 1 - - # Make edges undirected, by adding edges with senders and receivers flipped. - # The feature vector for the flipped edge is the same as the original edge. - new_senders = tf.concat([senders, receivers], axis=0) - new_receivers = tf.concat([receivers, senders], axis=0) - edges = tf.concat([edges, edges], axis=0) - senders, receivers = new_senders, new_receivers - num_edges *= 2 - - # Add self-loops for each node. - # The feature vectors for the self-loops are set to all zeros. 
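The bidirectional-edge idea being removed here in its TensorFlow form reappears in numpy inside _to_jraph below; as a minimal illustration with made-up arrays:

import numpy as np

senders = np.array([0, 1, 2])
receivers = np.array([1, 2, 0])
edge_feat = np.array([[1.0], [2.0], [3.0]])

# Append every edge once more with its endpoints swapped and its feature
# vector duplicated, doubling the edge count.
undirected_senders = np.concatenate([senders, receivers])
undirected_receivers = np.concatenate([receivers, senders])
undirected_edge_feat = np.concatenate([edge_feat, edge_feat])
assert undirected_senders.shape[0] == 2 * senders.shape[0]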
- senders = tf.concat([senders, tf.range(num_nodes)], axis=0) - receivers = tf.concat([receivers, tf.range(num_nodes)], axis=0) - edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) - num_edges += num_nodes +def _load_dataset(split, should_shuffle, data_rng, data_dir): + """Loads a dataset split from TFDS.""" + if should_shuffle: + file_data_rng, dataset_data_rng = jax.random.split(data_rng) + file_data_rng = file_data_rng[0] + dataset_data_rng = dataset_data_rng[0] + else: + file_data_rng = None + dataset_data_rng = None + + read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng) + dataset = tfds.load( + 'ogbg_molpcba', + split=split, + shuffle_files=should_shuffle, + read_config=read_config, + data_dir=data_dir) + + if should_shuffle: + dataset = dataset.shuffle( + seed=dataset_data_rng, buffer_size=2 ** 15) + dataset = dataset.repeat() + + return dataset + + +def _to_jraph(example): + """Converts an example graph to jraph.GraphsTuple.""" + example = jax.tree_map(lambda x: x._numpy(), example) # pylint: disable=protected-access + edge_feat = example['edge_feat'] + node_feat = example['node_feat'] + edge_index = example['edge_index'] + labels = example['labels'] + num_nodes = example['num_nodes'] + + senders = edge_index[:, 0] + receivers = edge_index[:, 1] return jraph.GraphsTuple( - n_node=tf.expand_dims(num_nodes, 0), - n_edge=tf.expand_dims(num_edges, 0), - nodes=nodes, - edges=edges, - senders=senders, - receivers=receivers, - globals=tf.expand_dims(labels, axis=0)) - - -def _get_valid_mask(graphs: jraph.GraphsTuple): - """Gets the binary mask indicating only valid labels and graphs.""" - labels = graphs.globals - # We have to ignore all NaN values - which indicate labels for which - # the current graphs have no label. - labels_masks = ~np.isnan(labels) - - # Since we have extra 'dummy' graphs in our batch due to padding, we want - # to mask out any loss associated with the dummy graphs. - # Since we padded with `pad_with_graphs` we can recover the mask by using - # get_graph_padding_mask. - graph_masks = jraph.get_graph_padding_mask(graphs) - - # Combine the mask over labels with the mask over graphs. - masks = labels_masks & graph_masks[:, None] - graphs = graphs._replace(globals=[]) - return graphs, labels, masks - - -def _batch_for_pmap(iterator): - graphs = [] - labels = [] - masks = [] + n_node=num_nodes, + n_edge=np.array([len(edge_index) * 2]), + nodes=node_feat, + edges=np.concatenate([edge_feat, edge_feat]), + # Make the edges bidirectional + senders=np.concatenate([senders, receivers]), + receivers=np.concatenate([receivers, senders]), + # Keep the labels with the graph for batching. They will be removed + # in the processed batch. + globals=np.expand_dims(labels, axis=0)) + + +def _get_weights_by_nan_and_padding(labels, padding_mask): + """Handles NaNs and padding in labels. + + Sets all the weights from examples coming from padding to 0. Changes all NaNs + in labels to 0s and sets the corresponding per-label weight to 0. 
+ + Args: + labels: Labels including labels from padded examples + padding_mask: Binary array of which examples are padding + Returns: + tuple of (processed labels, corresponding weights) + """ + nan_mask = np.isnan(labels) + replaced_labels = np.copy(labels) + np.place(replaced_labels, nan_mask, 0) + + weights = 1.0 - nan_mask + # Weights for all labels of a padded element will be 0 + weights = weights * padding_mask[:, None] + return replaced_labels, weights + + +def _get_batch_iterator(dataset_iter, batch_size, num_shards=None): + """Turns a per-example iterator into a batched iterator. + + Constructs the batch from num_shards smaller batches, so that we can easily + shard the batch to multiple devices during training. We use + dynamic batching, so we specify some max number of graphs/nodes/edges, add + as many graphs as we can, and then pad to the max values. + + Args: + dataset_iter: The TFDS dataset iterator. + batch_size: How many average-sized graphs go into the batch. + num_shards: How many devices we should be able to shard the batch into. + Yields: + Batch in the init2winit format. Each field is a list of num_shards separate + smaller batches. + """ + if not num_shards: + num_shards = jax.device_count() + + # We will construct num_shards smaller batches and then put them together. + batch_size /= num_shards + + max_n_nodes = AVG_NODES_PER_GRAPH * batch_size + max_n_edges = AVG_EDGES_PER_GRAPH * batch_size + max_n_graphs = batch_size + + jraph_iter = map(_to_jraph, dataset_iter) + batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes + 1, + max_n_edges, max_n_graphs + 1) + count = 0 - for batch in iterator: - graph_batch, label_batch, mask_batch = _get_valid_mask(batch) + graphs_shards = [] + labels_shards = [] + weights_shards = [] + + for batched_graph in batched_iter: count += 1 - graphs.append(graph_batch) - labels.append(label_batch) - masks.append(mask_batch) - if count == jax.local_device_count(): - graphs = jax.tree_multimap(lambda *x: np.stack(x, axis=0), *graphs) - labels = np.stack(labels) - masks = np.stack(masks) - yield graphs, labels, masks - graphs = [] - labels = [] - masks = [] - count = 0 + # Separate the labels from the graph + labels = batched_graph.globals + graph = batched_graph._replace(globals={}) -def get_dataset_iter(split_name: str, - data_rng: jax.random.PRNGKey, - data_dir: str, - batch_size: int, - add_virtual_node: bool = True, - add_undirected_edges: bool = True, - add_self_loops: bool = True) -> Dict[str, tf.data.Dataset]: - """Returns datasets of batched GraphsTuples, organized by split.""" - if batch_size <= 1: - raise ValueError('Batch size must be > 1 to account for padding graphs.') - - file_shuffle_seed, dataset_shuffle_seed = jax.random.split(data_rng) - file_shuffle_seed = file_shuffle_seed[0] - dataset_shuffle_seed = dataset_shuffle_seed[0] - - # Obtain the original datasets. - dataset = _get_raw_dataset(split_name, data_dir, file_shuffle_seed) - - # Construct the GraphsTuple converter function. - convert_to_graphs_tuple_fn = functools.partial( - convert_to_graphs_tuple, - add_virtual_node=add_self_loops, - add_undirected_edges=add_undirected_edges, - add_self_loops=add_virtual_node, - ) - - dataset = dataset.map( - convert_to_graphs_tuple_fn, - num_parallel_calls=tf.data.AUTOTUNE, - deterministic=True) - - # Repeat and shuffle the training split. 
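A hypothetical usage example for the NaN/padding handling above (toy arrays; the numpy steps mirror _get_weights_by_nan_and_padding, with np.where standing in for np.copy plus np.place):

import numpy as np

# Two graphs (the second is a padding graph), three binary tasks each; one
# label is missing and stored as NaN.
labels = np.array([[1.0, np.nan, 0.0],
                   [0.0, 0.0, 0.0]])
padding_mask = np.array([1.0, 0.0])  # 1 for real graphs, 0 for padding

nan_mask = np.isnan(labels)
clean_labels = np.where(nan_mask, 0.0, labels)      # NaNs become 0s...
weights = (1.0 - nan_mask) * padding_mask[:, None]  # ...and get weight 0

print(clean_labels)  # [[1., 0., 0.], [0., 0., 0.]]
print(weights)       # [[1., 0., 1.], [0., 0., 0.]]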
- if split_name == 'train': - dataset = dataset.shuffle( - buffer_size=2**15, - seed=dataset_shuffle_seed, - reshuffle_each_iteration=True) - dataset = dataset.repeat() - # We do not need to cache the validation and test sets because we do this - # later with itertools.cycle. + replaced_labels, weights = _get_weights_by_nan_and_padding( + labels, jraph.get_graph_padding_mask(graph)) - # Batch and pad each split. Note that this also converts the graphs to numpy. - max_n_nodes = AVG_NODES_PER_GRAPH * batch_size - max_n_edges = AVG_EDGES_PER_GRAPH * batch_size - batched_iter = jraph.dynamically_batch( - graphs_tuple_iterator=iter(dataset), - n_node=max_n_nodes, - n_edge=max_n_edges, - n_graph=batch_size) - - # An iterator the same as above, but where each element has an extra leading - # dim of size jax.local_device_count(). - pmapped_iterator = _batch_for_pmap(batched_iter) - return pmapped_iterator + graphs_shards.append(graph) + labels_shards.append(replaced_labels) + weights_shards.append(weights) + + if count == num_shards: + def f(x): + return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) + + graphs_shards = f(graphs_shards) + labels_shards = f(labels_shards) + weights_shards = f(weights_shards) + yield (graphs_shards, labels_shards, weights_shards) + + count = 0 + graphs_shards = [] + labels_shards = [] + weights_shards = [] + + +def get_dataset_iter(split, data_rng, data_dir, batch_size): + ds = _load_dataset( + split, + should_shuffle=(split == 'train'), + data_rng=data_rng, + data_dir=data_dir) + return _get_batch_iterator(iter(ds), batch_size) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py index 7410500b8..3a6fccf2e 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py @@ -17,9 +17,11 @@ def predictions_match_labels(*, logits: jnp.ndarray, labels: jnp.ndarray, return (preds == labels).astype(jnp.float32) +# Following the Flax OGB example: +# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py @flax.struct.dataclass class MeanAveragePrecision( - metrics.CollectingMetric.from_outputs(('labels', 'logits', 'mask'))): + metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))): """Computes the mean average precision (mAP) over different tasks.""" def compute(self): @@ -28,8 +30,7 @@ def compute(self): logits = self.values['logits'] mask = self.values['mask'] - assert logits.shape == labels.shape == mask.shape - assert len(logits.shape) == 2 + mask = mask.astype(np.bool) probs = jax.nn.sigmoid(logits) num_tasks = labels.shape[1] @@ -38,7 +39,8 @@ def compute(self): for task in range(num_tasks): # AP is only defined when there is at least one negative data # and at least one positive data. 
- if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: + if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, + task] == 1) > 0: is_labeled = mask[:, task] average_precisions[task] = average_precision_score( labels[is_labeled, task], probs[is_labeled, task]) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py index 80ba51bde..88aa7b1a9 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py @@ -1,176 +1,77 @@ -# Forked from Flax example which can be found here: -# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/models.py - -"""Definition of the GNN model.""" - -from typing import Callable, Sequence +# Forked from the init2winit implementation here +# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. +from typing import Tuple from flax import linen as nn import jax.numpy as jnp import jraph -def add_graphs_tuples(graphs: jraph.GraphsTuple, - other_graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: - """Adds the nodes, edges and global features from other_graphs to graphs.""" - return graphs._replace( - nodes=graphs.nodes + other_graphs.nodes, - edges=graphs.edges + other_graphs.edges, - globals=graphs.globals + other_graphs.globals) +def _make_embed(latent_dim): + def make_fn(inputs): + return nn.Dense(features=latent_dim)(inputs) -class MLP(nn.Module): - """A multi-layer perceptron.""" + return make_fn - feature_sizes: Sequence[int] - dropout_rate: float = 0 - deterministic: bool = True - activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - @nn.compact - def __call__(self, inputs): +def _make_mlp(hidden_dims, dropout): + """Creates a MLP with specified dimensions.""" + + @jraph.concatenated_args + def make_fn(inputs): x = inputs - for size in self.feature_sizes: - x = nn.Dense(features=size)(x) - x = self.activation(x) - x = nn.Dropout( - rate=self.dropout_rate, deterministic=self.deterministic)(x) + for dim in hidden_dims: + x = nn.Dense(features=dim)(x) + x = nn.LayerNorm()(x) + x = nn.relu(x) + x = dropout(x) return x + return make_fn -class GraphNet(nn.Module): - """A complete Graph Network model defined with Jraph.""" - latent_size: int - num_mlp_layers: int - message_passing_steps: int - output_globals_size: int - dropout_rate: float = 0 - skip_connections: bool = True - use_edge_model: bool = True - layer_norm: bool = True +class GNN(nn.Module): + """Defines a graph network. + The model assumes the input data is a jraph.GraphsTuple without global + variables. The final prediction will be encoded in the globals. + """ + num_outputs: int = 128 + latent_dim: int = 256 + hidden_dims: Tuple[int] = (256,) + dropout_rate: float = 0.1 + num_message_passing_steps: int = 5 @nn.compact - def __call__(self, graphs: jraph.GraphsTuple, train: bool) -> jraph.GraphsTuple: - # We will first linearly project the original features as 'embeddings'. - embedder = jraph.GraphMapFeatures( - embed_node_fn=nn.Dense(self.latent_size), - embed_edge_fn=nn.Dense(self.latent_size), - embed_global_fn=nn.Dense(self.latent_size)) - processed_graphs = embedder(graphs) - - # Now, we will apply a Graph Network once for each message-passing round. 
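_make_mlp relies on jraph.concatenated_args so a single MLP can serve as the edge, node, and global update function; a tiny standalone sketch of what that decorator does (toy shapes, not from the patch):

import jax.numpy as jnp
import jraph

@jraph.concatenated_args
def update_fn(features):
  # jraph concatenates all arguments along the last axis before calling this,
  # so one dense stack can consume edge, node, and global inputs alike.
  return features.sum(axis=-1, keepdims=True)

edge_feats = jnp.ones((4, 2))
sender_feats = jnp.ones((4, 3))
receiver_feats = jnp.ones((4, 3))
global_feats = jnp.ones((4, 5))
out = update_fn(edge_feats, sender_feats, receiver_feats, global_feats)
print(out.shape)  # (4, 1): the inputs were concatenated into a (4, 13) array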
- mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers - for _ in range(self.message_passing_steps): - if self.use_edge_model: - update_edge_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, - dropout_rate=self.dropout_rate, - deterministic=not train)) - else: - update_edge_fn = None - - update_node_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, - dropout_rate=self.dropout_rate, - deterministic=not train)) - update_global_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, - dropout_rate=self.dropout_rate, - deterministic=not train)) - - graph_net = jraph.GraphNetwork( - update_node_fn=update_node_fn, - update_edge_fn=update_edge_fn, - update_global_fn=update_global_fn) - - if self.skip_connections: - processed_graphs = add_graphs_tuples( - graph_net(processed_graphs), processed_graphs) - else: - processed_graphs = graph_net(processed_graphs) - - if self.layer_norm: - processed_graphs = processed_graphs._replace( - nodes=nn.LayerNorm()(processed_graphs.nodes), - edges=nn.LayerNorm()(processed_graphs.edges), - globals=nn.LayerNorm()(processed_graphs.globals), - ) - - # Since our graph-level predictions will be at globals, we will - # decode to get the required output logits. - decoder = jraph.GraphMapFeatures( - embed_global_fn=nn.Dense(self.output_globals_size)) - processed_graphs = decoder(processed_graphs) - - return processed_graphs - - -class GraphConvNet(nn.Module): - """A Graph Convolution Network + Pooling model defined with Jraph.""" - - latent_size: int - num_mlp_layers: int - message_passing_steps: int - output_globals_size: int - dropout_rate: float = 0 - skip_connections: bool = True - layer_norm: bool = True - pooling_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], - jnp.ndarray] = jraph.segment_mean - - def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: - """Pooling operation, taken from Jraph.""" - - # Equivalent to jnp.sum(n_node), but JIT-able. - sum_n_node = graphs.nodes.shape[0] - # To aggregate nodes from each graph to global features, - # we first construct tensors that map the node to the corresponding graph. - # Example: if you have `n_node=[1,2]`, we construct the tensor [0, 1, 1]. - n_graph = graphs.n_node.shape[0] - node_graph_indices = jnp.repeat( - jnp.arange(n_graph), - graphs.n_node, - axis=0, - total_repeat_length=sum_n_node) - # We use the aggregation function to pool the nodes per graph. - pooled = self.pooling_fn(graphs.nodes, node_graph_indices, n_graph) - return graphs._replace(globals=pooled) + def __call__(self, graph, train): + dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train) + + graph = graph._replace( + globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) - @nn.compact - def __call__(self, graphs: jraph.GraphsTuple, train: bool) -> jraph.GraphsTuple: - # We will first linearly project the original node features as 'embeddings'. embedder = jraph.GraphMapFeatures( - embed_node_fn=nn.Dense(self.latent_size)) - processed_graphs = embedder(graphs) - - # Now, we will apply the GCN once for each message-passing round. 
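Stepping back from the diff: a minimal sketch of how the GNN module introduced in this patch is applied to a batched jraph.GraphsTuple. The feature sizes below are toy values and the snippet is illustrative only, not part of the patches.

import jax
import jax.numpy as jnp
import jraph

# Toy graph: 3 nodes with 9 features and 4 edges with 3 features, matching the
# ogbg-molpcba feature dimensions.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 9)),
    edges=jnp.ones((4, 3)),
    senders=jnp.array([0, 1, 2, 2]),
    receivers=jnp.array([1, 2, 0, 1]),
    globals=None,  # The model overwrites globals with zeros internally.
    n_node=jnp.array([3]),
    n_edge=jnp.array([4]))

model = GNN()  # GNN as defined in models.py above; num_outputs defaults to 128.
rng = jax.random.PRNGKey(0)
variables = model.init({'params': rng, 'dropout': rng}, graph, train=False)
logits = model.apply(variables, graph, rngs={'dropout': rng}, train=True)
print(logits.shape)  # (1, 128): one row of per-task logits per graph.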
- for _ in range(self.message_passing_steps): - mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers - update_node_fn = jraph.concatenated_args( - MLP(mlp_feature_sizes, - dropout_rate=self.dropout_rate, - deterministic=not train)) - graph_conv = jraph.GraphConvolution( - update_node_fn=update_node_fn, add_self_edges=True) - - if self.skip_connections: - processed_graphs = add_graphs_tuples( - graph_conv(processed_graphs), processed_graphs) - else: - processed_graphs = graph_conv(processed_graphs) - - if self.layer_norm: - processed_graphs = processed_graphs._replace( - nodes=nn.LayerNorm()(processed_graphs.nodes), - ) - - # We apply the pooling operation to get a 'global' embedding. - processed_graphs = self.pool(processed_graphs) - - # Now, we decode this to get the required output logits. + embed_node_fn=_make_embed(self.latent_dim), + embed_edge_fn=_make_embed(self.latent_dim)) + graph = embedder(graph) + + for _ in range(self.num_message_passing_steps): + net = jraph.GraphNetwork( + update_edge_fn=_make_mlp( + self.hidden_dims, + dropout=dropout), + update_node_fn=_make_mlp( + self.hidden_dims, + dropout=dropout), + update_global_fn=_make_mlp( + self.hidden_dims, + dropout=dropout)) + + graph = net(graph) + + # Map globals to represent the final result decoder = jraph.GraphMapFeatures( - embed_global_fn=nn.Dense(self.output_globals_size)) - processed_graphs = decoder(processed_graphs) + embed_global_fn=nn.Dense(self.num_outputs)) + graph = decoder(graph) + + return graph.globals - return processed_graphs diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index 85c7c73b2..a1a73d539 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -25,14 +25,7 @@ def __init__(self): self._eval_iterator = None self._param_shapes = None self._init_graphs = None - self._model = models.GraphConvNet( - latent_size=256, - num_mlp_layers=2, - message_passing_steps=5, - output_globals_size=128, - dropout_rate=0.1, - skip_connections=True, - layer_norm=True) + self._model = models.GNN() def _build_iterator( self, @@ -44,12 +37,9 @@ def _build_iterator( split, data_rng, data_dir, - batch_size, - add_virtual_node=False, - add_undirected_edges=True, - add_self_loops=True) + batch_size) if self._init_graphs is None: - init_graphs = next(dataset_iter)[0] + init_graphs, _, _ = next(dataset_iter) # Unreplicate the iterator that has the leading dim for pmapping. self._init_graphs = jax.tree_map(lambda x: x[0], init_graphs) return dataset_iter @@ -107,7 +97,8 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: 'called before workload.init_model_fn()!' 
) rng, params_rng, dropout_rng = jax.random.split(rng, 3) - params = jax.jit(functools.partial(self._model.init, train=False))( + init_fn = jax.jit(functools.partial(self._model.init, train=False)) + params = init_fn( {'params': params_rng, 'dropout': dropout_rng}, self._init_graphs) params = params['params'] self._param_shapes = jax.tree_map( @@ -138,12 +129,11 @@ def model_fn( """Get predicted logits from the network for input graphs.""" assert model_state is None train = mode == spec.ForwardPassMode.TRAIN - pred_graphs = self._model.apply( + logits = self._model.apply( {'params': params}, input_batch, rngs={'dropout': rng}, train=train) - logits = pred_graphs.globals return logits, None def _binary_cross_entropy_with_mask( From 59b2c0bc5d1b64aa971a6b5d31d48d2b648a6c71 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Tue, 1 Mar 2022 00:43:27 -0500 Subject: [PATCH 37/46] cleaning up, adding more yapf config --- .../workloads/ogb/ogb_jax/input_pipeline.py | 10 +-- .../workloads/ogb/ogb_jax/metrics.py | 3 +- .../workloads/ogb/ogb_jax/workload.py | 10 +-- .../workloads/ogb/workload.py | 3 +- baselines/ogb/ogb_jax/submission.py | 61 ++++++++++--------- .../ogb/ogb_jax/tuning_search_space.json | 2 +- setup.cfg | 2 + submission_runner.py | 26 +++----- 8 files changed, 50 insertions(+), 67 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index 25d81dbc1..9481b91d1 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -5,8 +5,7 @@ """Exposes the ogbg-molpcba dataset in a convenient format.""" -import functools -from typing import Any, Dict, NamedTuple +from typing import NamedTuple import jax import jraph import numpy as np @@ -18,13 +17,6 @@ AVG_EDGES_PER_GRAPH = 56 -class GraphsTupleSize(NamedTuple): - """Helper class to represent padding and graph sizes.""" - n_node: int - n_edge: int - n_graph: int - - def _load_dataset(split, should_shuffle, data_rng, data_dir): """Loads a dataset split from TFDS.""" if should_shuffle: diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py index 3a6fccf2e..767a95251 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py @@ -39,8 +39,7 @@ def compute(self): for task in range(num_tasks): # AP is only defined when there is at least one negative data # and at least one positive data. 
- if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, - task] == 1) > 0: + if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: is_labeled = mask[:, task] average_precisions[task] = average_precision_score( labels[is_labeled, task], probs[is_labeled, task]) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index a1a73d539..0be970cba 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -1,15 +1,11 @@ """OGB workload implemented in Jax.""" -from absl import logging from typing import Optional, Tuple + import functools -import numpy as np -from flax import linen as nn from flax import jax_utils import itertools import jax import jax.numpy as jnp -import jraph -import sklearn.metrics from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency import spec @@ -127,6 +123,7 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" + del update_batch_norm # No BN in the GNN model. assert model_state is None train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply( @@ -210,8 +207,7 @@ def eval_model( total_eval_batch_size = eval_per_core_batch_size * jax.local_device_count() num_val_steps = num_val_examples // total_eval_batch_size + 1 # Loop over graph batches in eval dataset. - for s in range(num_val_steps): - # logging.info(f'eval step {s}') + for _ in range(num_val_steps): graphs, labels, masks = next(self._eval_iterator) batch_metrics = self._eval_batch( params, graphs, labels, masks, model_state, model_rng) diff --git a/algorithmic_efficiency/workloads/ogb/workload.py b/algorithmic_efficiency/workloads/ogb/workload.py index ee2a7c06e..956279655 100644 --- a/algorithmic_efficiency/workloads/ogb/workload.py +++ b/algorithmic_efficiency/workloads/ogb/workload.py @@ -8,7 +8,8 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - # From Flax example https://tensorboard.dev/experiment/AAJqfvgSRJaA1MBkc0jMWQ/#scalars. + # From Flax example + # https://tensorboard.dev/experiment/AAJqfvgSRJaA1MBkc0jMWQ/#scalars. 
return 0.24 @property diff --git a/baselines/ogb/ogb_jax/submission.py b/baselines/ogb/ogb_jax/submission.py index a9044994f..c827b2ba0 100644 --- a/baselines/ogb/ogb_jax/submission.py +++ b/baselines/ogb/ogb_jax/submission.py @@ -1,12 +1,9 @@ from typing import Iterator, List, Optional, Tuple -import functools -import numpy as np import jax import jax.numpy as jnp from jax import lax from flax import jax_utils -import jraph import optax from algorithmic_efficiency import spec @@ -17,36 +14,34 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def optimizer(hyperparameters: spec.Hyperparamters) -> optax.GradientTransformation: - """Creates an optimizer.""" - opt_init_fn, opt_update_fn = optax.adam( - learning_rate=hyperparameters.learning_rate) - return opt_init_fn, opt_update_fn - - def init_optimizer_state( workload: spec.Workload, model_params: spec.ParameterContainer, model_state: spec.ModelAuxiliaryState, hyperparameters: spec.Hyperparamters, rng: spec.RandomState) -> spec.OptimizerState: + """Creates an Adam optimizer.""" + del model_params + del model_state + del rng params_zeros_like = jax.tree_map( lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters) - init_optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(init_optimizer_state), opt_update_fn - - -# We need to jax.pmap here instead of inside update_params because the latter -# would recompile the function every step. -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, input_batch, - label_batch, mask_batch, rng): + opt_init_fn, opt_update_fn = opt_init_fn, opt_update_fn = optax.adam( + learning_rate=hyperparameters.learning_rate) + optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(optimizer_state), opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + input_batch, + label_batch, + mask_batch, + rng): del hyperparameters def loss_fn(params): @@ -70,7 +65,6 @@ def loss_fn(params): return new_model_state, new_optimizer_state, updated_params -@jax.profiler.annotate_function def update_params( workload: spec.Workload, current_param_container: spec.ParameterContainer, @@ -93,15 +87,16 @@ def update_params( del global_step optimizer_state, opt_update_fn = optimizer_state + pmapped_train_step = jax.pmap( + train_step, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0, 0, None), + static_broadcasted_argnums=(0, 1)) new_model_state, new_optimizer_state, new_params = pmapped_train_step( workload, opt_update_fn, model_state, optimizer_state, current_param_container, hyperparameters, input_batch, label_batch, mask_batch, rng) - #steps_per_epoch = workload.num_train_examples // get_batch_size('ogb_jax') - #if (global_step + 1) % steps_per_epoch == 0: - # # sync batch statistics across replicas once per epoch - # new_model_state = workload.sync_batch_stats(new_model_state) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -114,4 +109,10 @@ def data_selection( global_step: int, rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue.""" + del workload + del optimizer_state + del 
current_param_container + del hyperparameters + del global_step + del rng return next(input_queue) diff --git a/baselines/ogb/ogb_jax/tuning_search_space.json b/baselines/ogb/ogb_jax/tuning_search_space.json index 7aba31610..d50cc00c5 100644 --- a/baselines/ogb/ogb_jax/tuning_search_space.json +++ b/baselines/ogb/ogb_jax/tuning_search_space.json @@ -1 +1 @@ -{"learning_rate": {"feasible_points": [1e-3]}} \ No newline at end of file +{"learning_rate": {"feasible_points": [1e-3]}} diff --git a/setup.cfg b/setup.cfg index ec9ecdfda..d171ef1cf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -88,6 +88,8 @@ pytorch = # yapf configuration [yapf] based_on_style = yapf +each_dict_entry_on_separate_line = false +split_all_top_level_comma_separated_values = true # isort configuration [isort] diff --git a/submission_runner.py b/submission_runner.py index af44667b4..632b1d267 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -12,11 +12,8 @@ """ import importlib import inspect -import jax import json import os -# Enable flax xprof trace labelling. -os.environ['FLAX_PROFILE'] = 'true' import struct import time from typing import Optional, Tuple @@ -35,34 +32,34 @@ WORKLOADS = { 'mnist_jax': { 'workload_path': BASE_WORKLOADS_DIR + 'mnist/mnist_jax/workload.py', - 'workload_class_name': 'MnistWorkload' + 'workload_class_name': 'MnistWorkload', }, 'mnist_pytorch': { 'workload_path': BASE_WORKLOADS_DIR + 'mnist/mnist_pytorch/workload.py', - 'workload_class_name': 'MnistWorkload' + 'workload_class_name': 'MnistWorkload', }, 'imagenet_jax': { 'workload_path': BASE_WORKLOADS_DIR + 'imagenet/imagenet_jax/workload.py', - 'workload_class_name': 'ImagenetWorkload' + 'workload_class_name': 'ImagenetWorkload', }, 'imagenet_pytorch': { 'workload_path': BASE_WORKLOADS_DIR + 'imagenet/imagenet_pytorch/workload.py', - 'workload_class_name': 'ImagenetWorkload' + 'workload_class_name': 'ImagenetWorkload', }, 'ogb_jax': { 'workload_path': BASE_WORKLOADS_DIR + 'ogb/ogb_jax/workload.py', - 'workload_class_name': 'OGBWorkload' + 'workload_class_name': 'OGBWorkload', }, 'wmt_jax': { 'workload_path': BASE_WORKLOADS_DIR + 'wmt/wmt_jax/workload.py', - 'workload_class_name': 'WMTWorkload' + 'workload_class_name': 'WMTWorkload', }, 'librispeech_pytorch': { 'workload_path': BASE_WORKLOADS_DIR + 'librispeech/librispeech_pytorch/workload.py', - 'workload_class_name': 'LibriSpeechWorkload' + 'workload_class_name': 'LibriSpeechWorkload', } } @@ -176,13 +173,11 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, training_complete = False global_start_time = time.time() - # jax.profiler.start_trace("/tmp/tensorboard") logging.info('Starting training loop.') while (is_time_remaining and not goal_reached and not training_complete): step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) start_time = time.time() - # logging.info(f'starting step {global_step}') selected_train_input_batch, selected_train_label_batch, selected_train_mask_batch = data_selection( workload, input_queue, @@ -191,7 +186,6 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, hyperparameters, global_step, data_select_rng) - # logging.info(f'starting update {global_step}') try: optimizer_state, model_params, model_state = update_params( workload=workload, @@ -209,7 +203,6 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, rng=update_rng) except spec.TrainingCompleteError: training_complete = True - # logging.info(f'finished step {global_step}') 
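As background for the baseline submission above, which builds its optimizer with optax.adam and applies the returned update function inside the pmapped training step, the single-device optax update pattern looks like the following; the loss and values are toy placeholders, not part of the patches.

import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, -2.0, 3.0])
opt_init_fn, opt_update_fn = optax.adam(learning_rate=1e-3)
opt_state = opt_init_fn(params)

def loss_fn(p):
  return jnp.sum(p ** 2)  # Toy quadratic loss.

for _ in range(3):
  grads = jax.grad(loss_fn)(params)
  updates, opt_state = opt_update_fn(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
print(loss_fn(params))  # Slightly lower than the initial loss of 14.0.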
global_step += 1 current_time = time.time() accumulated_submission_time += current_time - start_time @@ -221,11 +214,10 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, latest_eval_result = workload.eval_model(model_params, model_state, eval_rng, data_dir) logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' - f'\t{latest_eval_result}') + f'\t{latest_eval_result}') last_eval_time = current_time eval_results.append((global_step, latest_eval_result)) goal_reached = workload.has_reached_goal(latest_eval_result) - # jax.profiler.stop_trace() metrics = {'eval_results': eval_results, 'global_step': global_step} return accumulated_submission_time, metrics @@ -239,7 +231,7 @@ def score_submission_on_workload( tuning_search_space: Optional[str] = None, num_tuning_trials: Optional[int] = None): # Remove the trailing '.py' and convert the filepath to a Python module. - submission_module_path = _convert_filepath_to_module(FLAGS.submission_path) + submission_module_path = _convert_filepath_to_module(submission_path) submission_module = importlib.import_module(submission_module_path) init_optimizer_state = submission_module.init_optimizer_state From b83dc8f621fd672c5f9d55c2d95599ac001d81d9 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Tue, 1 Mar 2022 00:48:05 -0500 Subject: [PATCH 38/46] final cleanup --- algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py | 3 +-- baselines/ogb/ogb_jax/submission.py | 1 + setup.py | 2 +- submission_runner.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py index 767a95251..f234e611a 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py @@ -17,8 +17,6 @@ def predictions_match_labels(*, logits: jnp.ndarray, labels: jnp.ndarray, return (preds == labels).astype(jnp.float32) -# Following the Flax OGB example: -# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py @flax.struct.dataclass class MeanAveragePrecision( metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))): @@ -36,6 +34,7 @@ def compute(self): num_tasks = labels.shape[1] average_precisions = np.full(num_tasks, np.nan) + # Note that this code is slow (~1 minute). for task in range(num_tasks): # AP is only defined when there is at least one negative data # and at least one positive data. diff --git a/baselines/ogb/ogb_jax/submission.py b/baselines/ogb/ogb_jax/submission.py index c827b2ba0..0646f1af0 100644 --- a/baselines/ogb/ogb_jax/submission.py +++ b/baselines/ogb/ogb_jax/submission.py @@ -10,6 +10,7 @@ def get_batch_size(workload_name): + # Return the per-device batch size. 
batch_sizes = {'ogb_jax': 256} return batch_sizes[workload_name] diff --git a/setup.py b/setup.py index 26e08e48e..3fa0c489a 100644 --- a/setup.py +++ b/setup.py @@ -2,4 +2,4 @@ if __name__ == "__main__": - setup() + setup() diff --git a/submission_runner.py b/submission_runner.py index 632b1d267..89fd239a0 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -212,7 +212,7 @@ def train_once(workload: spec.Workload, batch_size: int, data_dir: str, if (current_time - last_eval_time >= workload.eval_period_time_sec or training_complete): latest_eval_result = workload.eval_model(model_params, model_state, - eval_rng, data_dir) + eval_rng, data_dir) logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}' f'\t{latest_eval_result}') last_eval_time = current_time From 778e6834fb12b62276faa80e403ed46ca986f76d Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 2 Mar 2022 01:34:45 -0500 Subject: [PATCH 39/46] fixing pylint --- .../workloads/ogb/ogb_jax/input_pipeline.py | 2 -- algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index 9481b91d1..eccf33c6f 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -5,11 +5,9 @@ """Exposes the ogbg-molpcba dataset in a convenient format.""" -from typing import NamedTuple import jax import jraph import numpy as np -import tensorflow as tf import tensorflow_datasets as tfds diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index 0be970cba..d14e9e610 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -117,7 +117,7 @@ def loss_type(self): def model_fn( self, params: spec.ParameterContainer, - input_batch: spec.Tensor, + augmented_and_preprocessed_input_batch: spec.Tensor, model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, @@ -128,7 +128,7 @@ def model_fn( train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply( {'params': params}, - input_batch, + augmented_and_preprocessed_input_batch, rngs={'dropout': rng}, train=train) return logits, None From ba98ac7ac154ab996655ee6bc0f4163638fb341c Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 2 Mar 2022 01:36:40 -0500 Subject: [PATCH 40/46] moving and fixing submission runner test --- tests/test_submission_runner.py => submission_runner_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) rename tests/test_submission_runner.py => submission_runner_test.py (81%) diff --git a/tests/test_submission_runner.py b/submission_runner_test.py similarity index 81% rename from tests/test_submission_runner.py rename to submission_runner_test.py index 24fddb44a..a79d0499d 100644 --- a/tests/test_submission_runner.py +++ b/submission_runner_test.py @@ -2,8 +2,7 @@ import os -from algorithmic_efficiency.submission_runner import \ - _convert_filepath_to_module +from submission_runner import _convert_filepath_to_module def test_convert_filepath_to_module(): From 548cb13555af22af97ec00bfa2172847b5683701 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 2 Mar 2022 01:45:20 -0500 Subject: [PATCH 41/46] fixing import ordering --- .github/workflows/linting.yml | 4 ++-- 
.../workloads/ogb/ogb_jax/input_pipeline.py | 1 - algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py | 6 +++--- algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py | 8 ++++---- baselines/ogb/ogb_jax/submission.py | 4 ++-- setup.py | 1 - submission_runner.py | 3 +-- 7 files changed, 12 insertions(+), 15 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index b3395882a..1127b2b16 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -35,7 +35,7 @@ jobs: pip install isort - name: Run isort run: | - isort . --check + isort . --check --diff yapf: runs-on: ubuntu-latest @@ -51,4 +51,4 @@ jobs: pip install yapf - name: Run yapf run: | - yapf . --diff --recursive \ No newline at end of file + yapf . --diff --recursive diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index eccf33c6f..b9073f8a8 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -10,7 +10,6 @@ import numpy as np import tensorflow_datasets as tfds - AVG_NODES_PER_GRAPH = 26 AVG_EDGES_PER_GRAPH = 56 diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py index f234e611a..5197c1345 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py @@ -1,11 +1,11 @@ # Forked from Flax example which can be found here: # https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py -import numpy as np +from clu import metrics +import flax import jax import jax.numpy as jnp -import flax -from clu import metrics +import numpy as np from sklearn.metrics import average_precision_score diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index d14e9e610..690c0eaf2 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -1,18 +1,18 @@ """OGB workload implemented in Jax.""" +import functools +import itertools from typing import Optional, Tuple -import functools from flax import jax_utils -import itertools import jax import jax.numpy as jnp from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogb.workload import OGB from algorithmic_efficiency.workloads.ogb.ogb_jax import input_pipeline -from algorithmic_efficiency.workloads.ogb.ogb_jax import models from algorithmic_efficiency.workloads.ogb.ogb_jax import metrics +from algorithmic_efficiency.workloads.ogb.ogb_jax import models +from algorithmic_efficiency.workloads.ogb.workload import OGB class OGBWorkload(OGB): diff --git a/baselines/ogb/ogb_jax/submission.py b/baselines/ogb/ogb_jax/submission.py index 0646f1af0..23b71ba39 100644 --- a/baselines/ogb/ogb_jax/submission.py +++ b/baselines/ogb/ogb_jax/submission.py @@ -1,9 +1,9 @@ from typing import Iterator, List, Optional, Tuple +from flax import jax_utils import jax -import jax.numpy as jnp from jax import lax -from flax import jax_utils +import jax.numpy as jnp import optax from algorithmic_efficiency import spec diff --git a/setup.py b/setup.py index 3fa0c489a..a4ead8f48 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ from setuptools import setup - if __name__ == "__main__": setup() diff --git 
a/submission_runner.py b/submission_runner.py index 8710381dc..d209e2a71 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -24,9 +24,8 @@ import tensorflow as tf from algorithmic_efficiency import halton -from algorithmic_efficiency import spec from algorithmic_efficiency import random_utils as prng - +from algorithmic_efficiency import spec # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. From fcbcfe9cdd5e69dd28e9a5ab3a0c6a3ca68519e2 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 2 Mar 2022 01:50:45 -0500 Subject: [PATCH 42/46] fixing lint GH action config, fixing formatting via yapf --- .github/workflows/linting.yml | 2 +- algorithmic_efficiency/spec.py | 50 ++++---- .../workloads/librispeech/prepare_data.py | 10 +- .../workloads/ogb/ogb_jax/input_pipeline.py | 13 +- .../workloads/ogb/ogb_jax/metrics.py | 5 +- .../workloads/ogb/ogb_jax/models.py | 16 +-- .../workloads/ogb/ogb_jax/workload.py | 119 +++++++++--------- baselines/ogb/ogb_jax/submission.py | 15 ++- submission_runner.py | 25 ++-- 9 files changed, 119 insertions(+), 136 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 1127b2b16..3c6c06451 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -19,7 +19,7 @@ jobs: run: | pylint algorithmic_efficiency pylint baselines - pylint tests + pylint submission_runner_test.py isort: runs-on: ubuntu-latest diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index e74cf224e..3c7729d15 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -67,32 +67,30 @@ def __init__(self, shape_tuple): UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState] InitOptimizerFn = Callable[[ParameterShapeTree, Hyperparamters, RandomState], OptimizerState] -UpdateParamsFn = Callable[ - [ - ParameterContainer, - ParameterTypeTree, - ModelAuxiliaryState, - Hyperparamters, - Tensor, - Tensor, - LossType, - OptimizerState, - List[Tuple[int, float]], - int, - RandomState - ], - UpdateReturn] -DataSelectionFn = Callable[ - [ - Iterator[Tuple[Tensor, Tensor]], - OptimizerState, - ParameterContainer, - LossType, - Hyperparamters, - int, - RandomState - ], - Tuple[Tensor, Tensor]] +UpdateParamsFn = Callable[[ + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparamters, + Tensor, + Tensor, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] +DataSelectionFn = Callable[[ + Iterator[Tuple[Tensor, Tensor]], + OptimizerState, + ParameterContainer, + LossType, + Hyperparamters, + int, + RandomState +], + Tuple[Tensor, Tensor]] class Workload(metaclass=abc.ABCMeta): diff --git a/algorithmic_efficiency/workloads/librispeech/prepare_data.py b/algorithmic_efficiency/workloads/librispeech/prepare_data.py index 25f1ce421..046f14cae 100644 --- a/algorithmic_efficiency/workloads/librispeech/prepare_data.py +++ b/algorithmic_efficiency/workloads/librispeech/prepare_data.py @@ -29,9 +29,8 @@ def analyze_transcripts(train_data_dir, ignore_space=False): if i % 10 == 0: print(i) for chapter_folder in os.listdir(f'{train_data_dir}/{speaker_folder}'): - trans_file = ( - f'{train_data_dir}/{speaker_folder}/{chapter_folder}/' - f'{speaker_folder}-{chapter_folder}.trans.txt') + trans_file = (f'{train_data_dir}/{speaker_folder}/{chapter_folder}/' + f'{speaker_folder}-{chapter_folder}.trans.txt') with open(trans_file, 'r') as f: for line in f: _, trans = 
line.strip().split(' ', maxsplit=1) @@ -57,9 +56,8 @@ def get_txt(data_dir, labels_dict, ignore_space=False): if not speaker_folder.isdigit(): continue for chapter_folder in os.listdir(f'{data_dir}/{speaker_folder}'): - trans_file = ( - f'{data_dir}/{speaker_folder}/{chapter_folder}/' - f'{speaker_folder}-{chapter_folder}.trans.txt') + trans_file = (f'{data_dir}/{speaker_folder}/{chapter_folder}/' + f'{speaker_folder}-{chapter_folder}.trans.txt') with open(trans_file, 'r') as f: for l in f: utt, trans = l.strip().split(' ', maxsplit=1) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py index b9073f8a8..e0299122f 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py @@ -2,7 +2,6 @@ # https://github.com/google/flax/blob/main/examples/ogbg_molpcba/input_pipeline.py # and from the init2winit fork here # https://github.com/google/init2winit/blob/master/init2winit/dataset_lib/ogbg_molpcba.py - """Exposes the ogbg-molpcba dataset in a convenient format.""" import jax @@ -33,8 +32,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir): data_dir=data_dir) if should_shuffle: - dataset = dataset.shuffle( - seed=dataset_data_rng, buffer_size=2 ** 15) + dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15) dataset = dataset.repeat() return dataset @@ -114,8 +112,10 @@ def _get_batch_iterator(dataset_iter, batch_size, num_shards=None): max_n_graphs = batch_size jraph_iter = map(_to_jraph, dataset_iter) - batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes + 1, - max_n_edges, max_n_graphs + 1) + batched_iter = jraph.dynamically_batch(jraph_iter, + max_n_nodes + 1, + max_n_edges, + max_n_graphs + 1) count = 0 graphs_shards = [] @@ -137,6 +137,7 @@ def _get_batch_iterator(dataset_iter, batch_size, num_shards=None): weights_shards.append(weights) if count == num_shards: + def f(x): return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) @@ -151,7 +152,7 @@ def f(x): weights_shards = [] -def get_dataset_iter(split, data_rng, data_dir, batch_size): +def get_dataset_iter(split, data_rng, data_dir, batch_size): ds = _load_dataset( split, should_shuffle=(split == 'train'), diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py index 5197c1345..7175e0c25 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py @@ -9,7 +9,9 @@ from sklearn.metrics import average_precision_score -def predictions_match_labels(*, logits: jnp.ndarray, labels: jnp.ndarray, +def predictions_match_labels(*, + logits: jnp.ndarray, + labels: jnp.ndarray, **kwargs) -> jnp.ndarray: """Returns a binary array indicating where predictions match the labels.""" del kwargs # Unused. 
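One note on the shard stacking in the input pipeline above (the jax.tree_map over np.stack(vals, axis=0)): every field of the batch gets a leading axis with one entry per device, which is exactly the layout jax.pmap consumes. A toy sketch, not part of the patches:

import jax
import jax.numpy as jnp
import numpy as np

num_devices = jax.local_device_count()
per_device_batch_size = 4

# One shard per device, stacked along a new leading axis, giving an array of
# shape [num_devices, per_device_batch_size, 2].
shards = [np.full((per_device_batch_size, 2), i, dtype=np.float32)
          for i in range(num_devices)]
batch = np.stack(shards, axis=0)

# pmap maps over the leading axis, so each device sees one [4, 2] shard.
per_device_means = jax.pmap(jnp.mean)(batch)
print(per_device_means.shape)  # (num_devices,)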
@@ -54,4 +56,3 @@ class EvalMetrics(metrics.Collection): accuracy: metrics.Average.from_fun(predictions_match_labels) loss: metrics.Average.from_output('loss') mean_average_precision: MeanAveragePrecision - diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py index 88aa7b1a9..79682b31e 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py @@ -56,22 +56,14 @@ def __call__(self, graph, train): for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( - update_edge_fn=_make_mlp( - self.hidden_dims, - dropout=dropout), - update_node_fn=_make_mlp( - self.hidden_dims, - dropout=dropout), - update_global_fn=_make_mlp( - self.hidden_dims, - dropout=dropout)) + update_edge_fn=_make_mlp(self.hidden_dims, dropout=dropout), + update_node_fn=_make_mlp(self.hidden_dims, dropout=dropout), + update_global_fn=_make_mlp(self.hidden_dims, dropout=dropout)) graph = net(graph) # Map globals to represent the final result - decoder = jraph.GraphMapFeatures( - embed_global_fn=nn.Dense(self.num_outputs)) + decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs)) graph = decoder(graph) return graph.globals - diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py index 690c0eaf2..b33cf3fe4 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py @@ -23,29 +23,26 @@ def __init__(self): self._init_graphs = None self._model = models.GNN() - def _build_iterator( - self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - batch_size: int): - dataset_iter = input_pipeline.get_dataset_iter( - split, - data_rng, - data_dir, - batch_size) + def _build_iterator(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int): + dataset_iter = input_pipeline.get_dataset_iter(split, + data_rng, + data_dir, + batch_size) if self._init_graphs is None: init_graphs, _, _ = next(dataset_iter) # Unreplicate the iterator that has the leading dim for pmapping. 
self._init_graphs = jax.tree_map(lambda x: x[0], init_graphs) return dataset_iter - def build_input_queue( - self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - batch_size: int): + def build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int): return self._build_iterator(data_rng, split, data_dir, batch_size) @property @@ -64,24 +61,22 @@ def model_params_types(self): def is_output_params(self, param_key: spec.ParameterKey) -> bool: pass - def preprocess_for_train( - self, - selected_raw_input_batch: spec.Tensor, - selected_label_batch: spec.Tensor, - train_mean: spec.Tensor, - train_stddev: spec.Tensor, - rng: spec.RandomState) -> spec.Tensor: + def preprocess_for_train(self, + selected_raw_input_batch: spec.Tensor, + selected_label_batch: spec.Tensor, + train_mean: spec.Tensor, + train_stddev: spec.Tensor, + rng: spec.RandomState) -> spec.Tensor: del train_mean del train_stddev del rng return selected_raw_input_batch, selected_label_batch - def preprocess_for_eval( - self, - raw_input_batch: spec.Tensor, - raw_label_batch: spec.Tensor, - train_mean: spec.Tensor, - train_stddev: spec.Tensor) -> spec.Tensor: + def preprocess_for_eval(self, + raw_input_batch: spec.Tensor, + raw_label_batch: spec.Tensor, + train_mean: spec.Tensor, + train_stddev: spec.Tensor) -> spec.Tensor: del train_mean del train_stddev return raw_input_batch, raw_label_batch @@ -90,24 +85,21 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: if self._init_graphs is None: raise ValueError( 'This should not happen, workload.build_input_queue() should be ' - 'called before workload.init_model_fn()!' - ) + 'called before workload.init_model_fn()!') rng, params_rng, dropout_rng = jax.random.split(rng, 3) init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params = init_fn( - {'params': params_rng, 'dropout': dropout_rng}, self._init_graphs) + params = init_fn({'params': params_rng, 'dropout': dropout_rng}, + self._init_graphs) params = params['params'] - self._param_shapes = jax.tree_map( - lambda x: spec.ShapeTuple(x.shape), - params) + self._param_shapes = jax.tree_map(lambda x: spec.ShapeTuple(x.shape), + params) return jax_utils.replicate(params), None # Keep this separate from the loss function in order to support optimizers # that use the logits. - def output_activation_fn( - self, - logits_batch: spec.Tensor, - loss_type: spec.LossType) -> spec.Tensor: + def output_activation_fn(self, + logits_batch: spec.Tensor, + loss_type: spec.LossType) -> spec.Tensor: pass @property @@ -126,18 +118,16 @@ def model_fn( del update_batch_norm # No BN in the GNN model. 
assert model_state is None train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply( - {'params': params}, - augmented_and_preprocessed_input_batch, - rngs={'dropout': rng}, - train=train) + logits = self._model.apply({'params': params}, + augmented_and_preprocessed_input_batch, + rngs={'dropout': rng}, + train=train) return logits, None - def _binary_cross_entropy_with_mask( - self, - labels: jnp.ndarray, - logits: jnp.ndarray, - mask: jnp.ndarray) -> jnp.ndarray: + def _binary_cross_entropy_with_mask(self, + labels: jnp.ndarray, + logits: jnp.ndarray, + mask: jnp.ndarray) -> jnp.ndarray: """Binary cross entropy loss for logits, with masked elements.""" assert logits.shape == labels.shape == mask.shape assert len(logits.shape) == 2 @@ -151,8 +141,7 @@ def _binary_cross_entropy_with_mask( positive_logits = (logits >= 0) relu_logits = jnp.where(positive_logits, logits, 0) abs_logits = jnp.where(positive_logits, logits, -logits) - return relu_logits - (logits * labels) + ( - jnp.log(1 + jnp.exp(-abs_logits))) + return relu_logits - (logits * labels) + (jnp.log(1 + jnp.exp(-abs_logits))) # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @@ -186,12 +175,11 @@ def _eval_batch(self, params, graphs, labels, masks, model_state, rng): update_batch_norm=False) return self._eval_metric(labels, logits, masks) - def eval_model( - self, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str): + def eval_model(self, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str): """Run a full evaluation of the model.""" data_rng, model_rng = prng.split(rng, 2) eval_per_core_batch_size = 1024 @@ -209,10 +197,15 @@ def eval_model( # Loop over graph batches in eval dataset. 
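A brief aside on _binary_cross_entropy_with_mask above: the relu/abs rewrite is the standard numerically stable form of sigmoid cross entropy from logits, max(x, 0) - x*y + log(1 + exp(-|x|)). The sketch below, with illustrative values only and not part of the patches, checks it against the naive formulation for moderate logits where both stay finite.

import jax
import jax.numpy as jnp

def stable_bce_from_logits(logits, labels):
  positive_logits = (logits >= 0)
  relu_logits = jnp.where(positive_logits, logits, 0)
  abs_logits = jnp.where(positive_logits, logits, -logits)
  return relu_logits - logits * labels + jnp.log(1 + jnp.exp(-abs_logits))

def naive_bce_from_logits(logits, labels):
  probs = jax.nn.sigmoid(logits)
  return -(labels * jnp.log(probs) + (1 - labels) * jnp.log(1 - probs))

logits = jnp.array([-3.0, -0.5, 0.0, 2.0])
labels = jnp.array([0.0, 1.0, 1.0, 0.0])
print(jnp.allclose(stable_bce_from_logits(logits, labels),
                   naive_bce_from_logits(logits, labels)))  # True
# For large-magnitude logits the naive form loses precision and can return
# inf/nan, while the stable form stays finite.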
for _ in range(num_val_steps): graphs, labels, masks = next(self._eval_iterator) - batch_metrics = self._eval_batch( - params, graphs, labels, masks, model_state, model_rng) - total_metrics = (batch_metrics if total_metrics is None - else total_metrics.merge(batch_metrics)) + batch_metrics = self._eval_batch(params, + graphs, + labels, + masks, + model_state, + model_rng) + total_metrics = ( + batch_metrics + if total_metrics is None else total_metrics.merge(batch_metrics)) if total_metrics is None: return {} return {k: float(v) for k, v in total_metrics.reduce().compute().items()} diff --git a/baselines/ogb/ogb_jax/submission.py b/baselines/ogb/ogb_jax/submission.py index 23b71ba39..07ae4dee0 100644 --- a/baselines/ogb/ogb_jax/submission.py +++ b/baselines/ogb/ogb_jax/submission.py @@ -15,18 +15,17 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def init_optimizer_state( - workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparamters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparamters, + rng: spec.RandomState) -> spec.OptimizerState: """Creates an Adam optimizer.""" del model_params del model_state del rng - params_zeros_like = jax.tree_map( - lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) opt_init_fn, opt_update_fn = opt_init_fn, opt_update_fn = optax.adam( learning_rate=hyperparameters.learning_rate) optimizer_state = opt_init_fn(params_zeros_like) diff --git a/submission_runner.py b/submission_runner.py index d209e2a71..0f3638682 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,7 +31,6 @@ # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') - # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = "algorithmic_efficiency/workloads/" @@ -47,12 +46,14 @@ 'imagenet_jax': { 'workload_path': BASE_WORKLOADS_DIR + 'imagenet/imagenet_jax/workload.py', - 'workload_class_name': 'ImagenetWorkload', + 'workload_class_name': + 'ImagenetWorkload', }, 'imagenet_pytorch': { 'workload_path': BASE_WORKLOADS_DIR + 'imagenet/imagenet_pytorch/workload.py', - 'workload_class_name': 'ImagenetWorkload', + 'workload_class_name': + 'ImagenetWorkload', }, 'ogb_jax': { 'workload_path': BASE_WORKLOADS_DIR + 'ogb/ogb_jax/workload.py', @@ -65,7 +66,8 @@ 'librispeech_pytorch': { 'workload_path': BASE_WORKLOADS_DIR + 'librispeech/librispeech_pytorch/workload.py', - 'workload_class_name': 'LibriSpeechWorkload', + 'workload_class_name': + 'LibriSpeechWorkload', } } @@ -235,14 +237,13 @@ def train_once(workload: spec.Workload, return accumulated_submission_time, metrics -def score_submission_on_workload( - workload: spec.Workload, - workload_name: str, - submission_path: str, - data_dir: str, - tuning_ruleset: str, - tuning_search_space: Optional[str] = None, - num_tuning_trials: Optional[int] = None): +def score_submission_on_workload(workload: spec.Workload, + workload_name: str, + submission_path: str, + data_dir: str, + tuning_ruleset: str, + tuning_search_space: Optional[str] = None, + num_tuning_trials: Optional[int] = None): # Remove the trailing '.py' and convert the filepath to a Python module. 
submission_module_path = _convert_filepath_to_module(submission_path) submission_module = importlib.import_module(submission_module_path) From b067b00f116595e84b6b3c6a3eab1cf84df7f6a8 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Wed, 2 Mar 2022 01:52:24 -0500 Subject: [PATCH 43/46] moving TF to core deps, because we use it for some pytorch input pipelines? --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index f49fc63ab..b8cf3eb23 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,8 @@ install_requires = numpy>=1.19.2 pandas>=1.3.1 six>=1.15.0 + tensorflow-cpu==2.5.0 + tensorflow_datasets==4.4.0 python_requires = >=3.7 [options.extras_require] @@ -65,8 +67,6 @@ jax_core_deps = flax==0.3.5 jax==0.2.17 optax==0.0.9 - tensorflow-cpu==2.5.0 - tensorflow_datasets==4.4.0 # JAX CPU jax_cpu = From 58e0f9546cc13f6966786444081b5ace49e8c326 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Tue, 8 Mar 2022 12:44:45 -0500 Subject: [PATCH 44/46] renaming OGB to OGBG like the dataset. fixing and documenting we return the global batch size. returning per-example losses in OGBG. --- .../workloads/{ogb => ogbg}/__init__.py | 0 .../{ogb/ogb_jax => ogbg/ogbg_jax}/README.md | 0 .../ogb_jax => ogbg/ogbg_jax}/__init__.py | 0 .../ogbg_jax}/input_pipeline.py | 16 ++++++------- .../{ogb/ogb_jax => ogbg/ogbg_jax}/metrics.py | 0 .../{ogb/ogb_jax => ogbg/ogbg_jax}/models.py | 0 .../ogb_jax => ogbg/ogbg_jax}/workload.py | 23 +++++++++---------- .../workloads/{ogb => ogbg}/workload.py | 2 +- baselines/imagenet/imagenet_jax/submission.py | 1 + .../imagenet/imagenet_pytorch/submission.py | 1 + .../librispeech_pytorch/submission.py | 1 + baselines/mnist/mnist_jax/submission.py | 1 + baselines/mnist/mnist_pytorch/submission.py | 1 + .../ogb_jax => ogbg/ogbg_jax}/submission.py | 9 ++++---- .../ogbg_jax}/tuning_search_space.json | 0 baselines/wmt/wmt_jax/submission.py | 1 + 16 files changed, 31 insertions(+), 25 deletions(-) rename algorithmic_efficiency/workloads/{ogb => ogbg}/__init__.py (100%) rename algorithmic_efficiency/workloads/{ogb/ogb_jax => ogbg/ogbg_jax}/README.md (100%) rename algorithmic_efficiency/workloads/{ogb/ogb_jax => ogbg/ogbg_jax}/__init__.py (100%) rename algorithmic_efficiency/workloads/{ogb/ogb_jax => ogbg/ogbg_jax}/input_pipeline.py (91%) rename algorithmic_efficiency/workloads/{ogb/ogb_jax => ogbg/ogbg_jax}/metrics.py (100%) rename algorithmic_efficiency/workloads/{ogb/ogb_jax => ogbg/ogbg_jax}/models.py (100%) rename algorithmic_efficiency/workloads/{ogb/ogb_jax => ogbg/ogbg_jax}/workload.py (91%) rename algorithmic_efficiency/workloads/{ogb => ogbg}/workload.py (96%) rename baselines/{ogb/ogb_jax => ogbg/ogbg_jax}/submission.py (93%) rename baselines/{ogb/ogb_jax => ogbg/ogbg_jax}/tuning_search_space.json (100%) diff --git a/algorithmic_efficiency/workloads/ogb/__init__.py b/algorithmic_efficiency/workloads/ogbg/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/ogb/__init__.py rename to algorithmic_efficiency/workloads/ogbg/__init__.py diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/README.md b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/README.md similarity index 100% rename from algorithmic_efficiency/workloads/ogb/ogb_jax/README.md rename to algorithmic_efficiency/workloads/ogbg/ogbg_jax/README.md diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/__init__.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/__init__.py similarity index 100% rename from 
algorithmic_efficiency/workloads/ogb/ogb_jax/__init__.py rename to algorithmic_efficiency/workloads/ogbg/ogbg_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py similarity index 91% rename from algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py rename to algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py index e0299122f..b19897661 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py @@ -85,7 +85,7 @@ def _get_weights_by_nan_and_padding(labels, padding_mask): return replaced_labels, weights -def _get_batch_iterator(dataset_iter, batch_size, num_shards=None): +def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): """Turns a per-example iterator into a batched iterator. Constructs the batch from num_shards smaller batches, so that we can easily @@ -95,7 +95,7 @@ def _get_batch_iterator(dataset_iter, batch_size, num_shards=None): Args: dataset_iter: The TFDS dataset iterator. - batch_size: How many average-sized graphs go into the batch. + global_batch_size: How many average-sized graphs go into the batch. num_shards: How many devices we should be able to shard the batch into. Yields: Batch in the init2winit format. Each field is a list of num_shards separate @@ -105,11 +105,11 @@ def _get_batch_iterator(dataset_iter, batch_size, num_shards=None): num_shards = jax.device_count() # We will construct num_shards smaller batches and then put them together. - batch_size /= num_shards + global_batch_size /= num_shards - max_n_nodes = AVG_NODES_PER_GRAPH * batch_size - max_n_edges = AVG_EDGES_PER_GRAPH * batch_size - max_n_graphs = batch_size + max_n_nodes = AVG_NODES_PER_GRAPH * global_batch_size + max_n_edges = AVG_EDGES_PER_GRAPH * global_batch_size + max_n_graphs = global_batch_size jraph_iter = map(_to_jraph, dataset_iter) batched_iter = jraph.dynamically_batch(jraph_iter, @@ -152,10 +152,10 @@ def f(x): weights_shards = [] -def get_dataset_iter(split, data_rng, data_dir, batch_size): +def get_dataset_iter(split, data_rng, data_dir, global_batch_size): ds = _load_dataset( split, should_shuffle=(split == 'train'), data_rng=data_rng, data_dir=data_dir) - return _get_batch_iterator(iter(ds), batch_size) + return _get_batch_iterator(iter(ds), global_batch_size) diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/metrics.py similarity index 100% rename from algorithmic_efficiency/workloads/ogb/ogb_jax/metrics.py rename to algorithmic_efficiency/workloads/ogbg/ogbg_jax/metrics.py diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/ogb/ogb_jax/models.py rename to algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py diff --git a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py similarity index 91% rename from algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py rename to algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index b33cf3fe4..352c7862f 100644 --- a/algorithmic_efficiency/workloads/ogb/ogb_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -9,13 +9,13 @@ from algorithmic_efficiency import random_utils as prng from 
algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogb.ogb_jax import input_pipeline -from algorithmic_efficiency.workloads.ogb.ogb_jax import metrics -from algorithmic_efficiency.workloads.ogb.ogb_jax import models -from algorithmic_efficiency.workloads.ogb.workload import OGB +from algorithmic_efficiency.workloads.ogbg.ogbg_jax import input_pipeline +from algorithmic_efficiency.workloads.ogbg.ogbg_jax import metrics +from algorithmic_efficiency.workloads.ogbg.ogbg_jax import models +from algorithmic_efficiency.workloads.ogbg.workload import OGBG -class OGBWorkload(OGB): +class OGBGWorkload(OGBG): def __init__(self): self._eval_iterator = None @@ -150,13 +150,13 @@ def loss_fn( label_batch: spec.Tensor, logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor]) -> spec.Tensor: # differentiable - loss = self._binary_cross_entropy_with_mask( + per_example_losses = self._binary_cross_entropy_with_mask( labels=label_batch, logits=logits_batch, mask=mask_batch) - mean_loss = jnp.sum(jnp.where(mask_batch, loss, 0)) / jnp.sum(mask_batch) - return mean_loss + return per_example_losses def _eval_metric(self, labels, logits, masks): - loss = self.loss_fn(labels, logits, masks) + per_example_losses = self.loss_fn(labels, logits, masks) + loss = jnp.sum(jnp.where(masks, per_example_losses, 0)) / jnp.sum(masks) return metrics.EvalMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=masks) @@ -182,17 +182,16 @@ def eval_model(self, data_dir: str): """Run a full evaluation of the model.""" data_rng, model_rng = prng.split(rng, 2) - eval_per_core_batch_size = 1024 + total_eval_batch_size = 8192 if self._eval_iterator is None: self._eval_iterator = self._build_iterator( - data_rng, 'validation', data_dir, batch_size=eval_per_core_batch_size) + data_rng, 'validation', data_dir, batch_size=total_eval_batch_size) # Note that this effectively stores the entire val dataset in memory. self._eval_iterator = itertools.cycle(self._eval_iterator) total_metrics = None # Both val and test have the same (prime) number of examples. num_val_examples = 43793 - total_eval_batch_size = eval_per_core_batch_size * jax.local_device_count() num_val_steps = num_val_examples // total_eval_batch_size + 1 # Loop over graph batches in eval dataset. for _ in range(num_val_steps): diff --git a/algorithmic_efficiency/workloads/ogb/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/ogb/workload.py rename to algorithmic_efficiency/workloads/ogbg/workload.py index 956279655..f8a0ee2e3 100644 --- a/algorithmic_efficiency/workloads/ogb/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -1,7 +1,7 @@ from algorithmic_efficiency import spec -class OGB(spec.Workload): +class OGBG(spec.Workload): def has_reached_goal(self, eval_result: float) -> bool: return eval_result['mean_average_precision'] > self.target_value diff --git a/baselines/imagenet/imagenet_jax/submission.py b/baselines/imagenet/imagenet_jax/submission.py index 8d67b2272..e5ff5adf4 100644 --- a/baselines/imagenet/imagenet_jax/submission.py +++ b/baselines/imagenet/imagenet_jax/submission.py @@ -13,6 +13,7 @@ def get_batch_size(workload_name): + # Return the global batch size. 
del workload_name return 128 diff --git a/baselines/imagenet/imagenet_pytorch/submission.py b/baselines/imagenet/imagenet_pytorch/submission.py index 5b52e0034..57412b02a 100644 --- a/baselines/imagenet/imagenet_pytorch/submission.py +++ b/baselines/imagenet/imagenet_pytorch/submission.py @@ -10,6 +10,7 @@ def get_batch_size(workload_name): + # Return the global batch size. batch_sizes = {'imagenet_pytorch': 128} return batch_sizes[workload_name] diff --git a/baselines/librispeech/librispeech_pytorch/submission.py b/baselines/librispeech/librispeech_pytorch/submission.py index de666bc28..9f24cfa8e 100644 --- a/baselines/librispeech/librispeech_pytorch/submission.py +++ b/baselines/librispeech/librispeech_pytorch/submission.py @@ -10,6 +10,7 @@ def get_batch_size(workload_name): + # Return the global batch size. batch_sizes = {"librispeech_pytorch": 8} return batch_sizes[workload_name] diff --git a/baselines/mnist/mnist_jax/submission.py b/baselines/mnist/mnist_jax/submission.py index b71956ef2..8ca9c0e49 100644 --- a/baselines/mnist/mnist_jax/submission.py +++ b/baselines/mnist/mnist_jax/submission.py @@ -12,6 +12,7 @@ def get_batch_size(workload_name): + # Return the global batch size. batch_sizes = {'mnist_jax': 1024} return batch_sizes[workload_name] diff --git a/baselines/mnist/mnist_pytorch/submission.py b/baselines/mnist/mnist_pytorch/submission.py index a6563fdc8..cdc755abd 100644 --- a/baselines/mnist/mnist_pytorch/submission.py +++ b/baselines/mnist/mnist_pytorch/submission.py @@ -10,6 +10,7 @@ def get_batch_size(workload_name): + # Return the global batch size. batch_sizes = {'mnist_pytorch': 1024} return batch_sizes[workload_name] diff --git a/baselines/ogb/ogb_jax/submission.py b/baselines/ogbg/ogbg_jax/submission.py similarity index 93% rename from baselines/ogb/ogb_jax/submission.py rename to baselines/ogbg/ogbg_jax/submission.py index 07ae4dee0..522a1fdd4 100644 --- a/baselines/ogb/ogb_jax/submission.py +++ b/baselines/ogbg/ogbg_jax/submission.py @@ -10,8 +10,8 @@ def get_batch_size(workload_name): - # Return the per-device batch size. - batch_sizes = {'ogb_jax': 256} + # Return the global batch size. + batch_sizes = {'ogb_jax': 2048} return batch_sizes[workload_name] @@ -52,8 +52,9 @@ def loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True) - loss = workload.loss_fn(label_batch, logits_batch, mask_batch) - mean_loss = jnp.sum(jnp.where(mask_batch, loss, 0)) / jnp.sum(mask_batch) + per_example_losses = workload.loss_fn(label_batch, logits_batch, mask_batch) + mean_loss = (jnp.sum(jnp.where(mask_batch, per_example_losses, 0)) / + jnp.sum(mask_batch)) return mean_loss, new_model_state grad_fn = jax.value_and_grad(loss_fn, has_aux=True) diff --git a/baselines/ogb/ogb_jax/tuning_search_space.json b/baselines/ogbg/ogbg_jax/tuning_search_space.json similarity index 100% rename from baselines/ogb/ogb_jax/tuning_search_space.json rename to baselines/ogbg/ogbg_jax/tuning_search_space.json diff --git a/baselines/wmt/wmt_jax/submission.py b/baselines/wmt/wmt_jax/submission.py index 53b53a4f9..8aa9036bf 100644 --- a/baselines/wmt/wmt_jax/submission.py +++ b/baselines/wmt/wmt_jax/submission.py @@ -16,6 +16,7 @@ def get_batch_size(workload_name): + # Return the global batch size. 
batch_sizes = {"wmt_jax": 16} return batch_sizes[workload_name] From f91d4f14d58f9bf1340dfe347346b2c467e24488 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Tue, 8 Mar 2022 12:47:22 -0500 Subject: [PATCH 45/46] lint --- baselines/ogbg/ogbg_jax/submission.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/baselines/ogbg/ogbg_jax/submission.py b/baselines/ogbg/ogbg_jax/submission.py index 522a1fdd4..5437e0bdc 100644 --- a/baselines/ogbg/ogbg_jax/submission.py +++ b/baselines/ogbg/ogbg_jax/submission.py @@ -53,8 +53,9 @@ def loss_fn(params): rng, update_batch_norm=True) per_example_losses = workload.loss_fn(label_batch, logits_batch, mask_batch) - mean_loss = (jnp.sum(jnp.where(mask_batch, per_example_losses, 0)) / - jnp.sum(mask_batch)) + mean_loss = ( + jnp.sum(jnp.where(mask_batch, per_example_losses, 0)) / + jnp.sum(mask_batch)) return mean_loss, new_model_state grad_fn = jax.value_and_grad(loss_fn, has_aux=True) From c2ba214faf34471e983f82ed74734c992f0966f4 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Tue, 8 Mar 2022 13:52:29 -0500 Subject: [PATCH 46/46] renaming vars in input pipeline to make more sense --- .../workloads/ogbg/ogbg_jax/input_pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py index b19897661..9435c3769 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/input_pipeline.py @@ -105,11 +105,11 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): num_shards = jax.device_count() # We will construct num_shards smaller batches and then put them together. - global_batch_size /= num_shards + per_device_batch_size = global_batch_size / num_shards - max_n_nodes = AVG_NODES_PER_GRAPH * global_batch_size - max_n_edges = AVG_EDGES_PER_GRAPH * global_batch_size - max_n_graphs = global_batch_size + max_n_nodes = AVG_NODES_PER_GRAPH * per_device_batch_size + max_n_edges = AVG_EDGES_PER_GRAPH * per_device_batch_size + max_n_graphs = per_device_batch_size jraph_iter = map(_to_jraph, dataset_iter) batched_iter = jraph.dynamically_batch(jraph_iter,