Merge pull request #88 from runame/graph-pytorch
OGBG PyTorch Workload
znado authored Jul 8, 2022
2 parents 14a202e + addb65d commit 813e328
Showing 25 changed files with 660 additions and 147 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -18,6 +18,7 @@ jobs:
pip install .[jax_cpu]
pip install .[pytorch_cpu]
pip install .[wmt]
+ pip install .[ogbg]
- name: Run pytest
run: |
pytest -vx tests/version_test.py
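For context, `pip install .[ogbg]` resolves through the package's extras. A hypothetical sketch of the corresponding `setup.py` entry; the actual `setup.py` is not shown in this diff, and the dependency list below is only a guess based on the imports this PR adds:

# Hypothetical sketch -- not part of this commit. The real extras_require
# lives in the repository's setup.py, which this capture does not include.
extras_require={
    'ogbg': [
        'jraph',         # GraphsTuple batching used by the OGBG models
        'scikit-learn',  # average_precision_score in the OGBG metrics
    ],
}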
12 changes: 12 additions & 0 deletions algorithmic_efficiency/pytorch_utils.py
@@ -0,0 +1,12 @@
+ import os
+ from typing import Tuple
+
+ import torch
+
+
+ def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
+   use_pytorch_ddp = 'LOCAL_RANK' in os.environ
+   rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0
+   device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
+   n_gpus = torch.cuda.device_count()
+   return use_pytorch_ddp, rank, device, n_gpus
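A minimal usage sketch, not part of this commit: under `torchrun --nproc_per_node=N`, each process sees its own `LOCAL_RANK`, so `pytorch_setup()` lets every workload branch on DDP mode without repeating the environment probing. The `init_process_group` placement is an assumption about launcher setup:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from algorithmic_efficiency.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()

model = torch.nn.Linear(8, 2).to(DEVICE)
if USE_PYTORCH_DDP:
  # Assumes one process per GPU launched via torchrun, which sets LOCAL_RANK
  # (read by pytorch_setup) plus RANK/WORLD_SIZE (read by init_process_group).
  dist.init_process_group('nccl')
  model = DDP(model, device_ids=[RANK], output_device=RANK)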
algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py
@@ -2,7 +2,6 @@

import contextlib
import math
- import os
import random
from typing import Dict, Tuple

@@ -17,15 +16,13 @@
from algorithmic_efficiency import data_utils
from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
+ from algorithmic_efficiency.pytorch_utils import pytorch_setup
import algorithmic_efficiency.random_utils as prng
from algorithmic_efficiency.workloads.cifar.workload import BaseCifarWorkload
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \
resnet18

- USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
- RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
- DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
- N_GPUS = torch.cuda.device_count()
+ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()


class CifarWorkload(BaseCifarWorkload):
algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
@@ -15,6 +15,7 @@
from algorithmic_efficiency import data_utils
from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
+ from algorithmic_efficiency.pytorch_utils import pytorch_setup
import algorithmic_efficiency.random_utils as prng
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.input_pipeline import \
RandomMask
@@ -27,10 +28,7 @@
from algorithmic_efficiency.workloads.fastmri.workload import \
BaseFastMRIWorkload

- USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
- RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
- DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
- N_GPUS = torch.cuda.device_count()
+ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()


def ssim(gt: torch.Tensor,
algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -16,16 +16,14 @@
from algorithmic_efficiency import data_utils
from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
+ from algorithmic_efficiency.pytorch_utils import pytorch_setup
import algorithmic_efficiency.random_utils as prng
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \
resnet50
from algorithmic_efficiency.workloads.imagenet_resnet.workload import \
BaseImagenetResNetWorkload

- USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
- RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
- DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
- N_GPUS = torch.cuda.device_count()
+ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()


class ImagenetResNetWorkload(BaseImagenetResNetWorkload):
algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -1,13 +1,13 @@
"""ImageNet ViT workload implemented in PyTorch."""

import contextlib
- import os
from typing import Dict, Tuple

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from algorithmic_efficiency import spec
+ from algorithmic_efficiency.pytorch_utils import pytorch_setup
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \
ImagenetResNetWorkload
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch import \
@@ -17,10 +17,7 @@
from algorithmic_efficiency.workloads.imagenet_vit.workload import \
decode_variant

- USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
- RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
- DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
- N_GPUS = torch.cuda.device_count()
+ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()


# Make sure we inherit from the ViT base workload first.
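The inheritance-order comment above matters because of Python's method resolution order. A toy sketch, with assumed stand-in class names rather than code from the repo, of why the ViT base must come first:

class BaseVit:          # stands in for the ViT base workload
  def model_name(self):
    return 'vit'

class ResNetWorkload:   # stands in for ImagenetResNetWorkload
  def model_name(self):
    return 'resnet'

class VitWorkload(BaseVit, ResNetWorkload):
  pass

# Python's MRO searches bases left to right, so the ViT base wins any overlap:
assert VitWorkload().model_name() == 'vit'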
algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py
@@ -1,7 +1,6 @@
"""MNIST workload implemented in PyTorch."""
from collections import OrderedDict
import contextlib
- import os
from typing import Any, Dict, Tuple

import torch
@@ -15,12 +14,10 @@
from algorithmic_efficiency import data_utils
from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
+ from algorithmic_efficiency.pytorch_utils import pytorch_setup
from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload

- USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
- RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
- DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
- N_GPUS = torch.cuda.device_count()
+ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()


class _Model(nn.Module):
algorithmic_efficiency/workloads/ogbg/input_pipeline.py
@@ -26,7 +26,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir):
read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng)
dataset = tfds.load(
'ogbg_molpcba',
-     split='train' if split == 'eval_train' else split,
+     split=split,
shuffle_files=should_shuffle,
read_config=read_config,
data_dir=data_dir)
algorithmic_efficiency/workloads/ogbg/metrics.py
@@ -1,12 +1,19 @@
# Forked from Flax example which can be found here:
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py
+ from typing import Any

from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.metrics import average_precision_score
+ import torch
+ import torch.distributed as dist
+
+ from algorithmic_efficiency.pytorch_utils import pytorch_setup
+
+ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()


def predictions_match_labels(*,
@@ -26,9 +33,21 @@ class MeanAveragePrecision(

  def compute(self):
    # Matches the official OGB evaluation scheme for mean average precision.
-   labels = self.values['labels']
-   logits = self.values['logits']
-   mask = self.values['mask']
+   values = super().compute()
+   labels = values['labels']
+   logits = values['logits']
+   mask = values['mask']

+   if USE_PYTORCH_DDP:
+     # Sync labels, logits, and masks across devices.
+     all_values = [labels, logits, mask]
+     for idx, array in enumerate(all_values):
+       tensor = torch.as_tensor(array, device=DEVICE)
+       # Assumes that the tensors on all devices have the same shape.
+       all_tensors = [torch.zeros_like(tensor) for _ in range(N_GPUS)]
+       dist.all_gather(all_tensors, tensor)
+       all_values[idx] = torch.cat(all_tensors).cpu().numpy()
+     labels, logits, mask = all_values

mask = mask.astype(np.bool)

@@ -51,8 +70,25 @@ def compute(self):
return np.nanmean(average_precisions)


+ class AverageDDP(metrics.Average):
+   """Supports syncing metrics for PyTorch distributed data parallel (DDP)."""
+
+   def compute(self) -> Any:
+     if USE_PYTORCH_DDP:
+       # Sync counts across devices.
+       total_tensor = torch.as_tensor(np.asarray(self.total), device=DEVICE)
+       count_tensor = torch.as_tensor(np.asarray(self.count), device=DEVICE)
+       dist.all_reduce(total_tensor)
+       dist.all_reduce(count_tensor)
+       # Hacky way to avoid FrozenInstanceError
+       # (https://docs.python.org/3/library/dataclasses.html#frozen-instances).
+       object.__setattr__(self, 'total', total_tensor.cpu().numpy())
+       object.__setattr__(self, 'count', count_tensor.cpu().numpy())
+     return super().compute()


@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
-   accuracy: metrics.Average.from_fun(predictions_match_labels)
-   loss: metrics.Average.from_output('loss')
+   accuracy: AverageDDP.from_fun(predictions_match_labels)
+   loss: AverageDDP.from_output('loss')
mean_average_precision: MeanAveragePrecision
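For orientation, a minimal sketch of how a clu `metrics.Collection` such as `EvalMetrics` is typically accumulated; the eval loop and batch keys here are hypothetical, not code from this PR. Each batch produces an update via `single_from_model_output`, updates are merged, and `compute()` yields plain values, at which point the DDP hooks above perform their `all_gather`/`all_reduce` syncs:

total_metrics = None
for batch in eval_batches:  # hypothetical iterable of eval batches
  update = EvalMetrics.single_from_model_output(
      logits=batch['logits'],  # per-batch model outputs (assumed keys)
      labels=batch['targets'],
      mask=batch['weights'],
      loss=batch['per_example_loss'])
  total_metrics = (
      update if total_metrics is None else total_metrics.merge(update))

# Triggers MeanAveragePrecision.compute() / AverageDDP.compute() above.
results = {k: float(v) for k, v in total_metrics.compute().items()}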
103 changes: 7 additions & 96 deletions algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py
@@ -1,51 +1,25 @@
"""OGB workload implemented in Jax."""
"""OGBG workload implemented in Jax."""
import functools
- import itertools
- import math
- from typing import Dict, Optional, Tuple
+ from typing import Dict, Tuple

from flax import jax_utils
import jax
import jax.numpy as jnp
import jraph

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import random_utils as prng
from algorithmic_efficiency import spec
- from algorithmic_efficiency.workloads.ogbg.ogbg_jax import input_pipeline
- from algorithmic_efficiency.workloads.ogbg.ogbg_jax import metrics
+ from algorithmic_efficiency.workloads.ogbg import metrics
from algorithmic_efficiency.workloads.ogbg.ogbg_jax import models
from algorithmic_efficiency.workloads.ogbg.workload import BaseOgbgWorkload


class OgbgWorkload(BaseOgbgWorkload):

  def __init__(self):
-   self._eval_iters = {}
-   self._param_shapes = None
-   self._param_types = None
-   self._num_outputs = 128
+   super().__init__()
    self._model = models.GNN(self._num_outputs)

- def build_input_queue(self,
-                       data_rng: jax.random.PRNGKey,
-                       split: str,
-                       data_dir: str,
-                       global_batch_size: int):
-   dataset_iter = input_pipeline.get_dataset_iter(split,
-                                                  data_rng,
-                                                  data_dir,
-                                                  global_batch_size)
-   return dataset_iter

- @property
- def param_shapes(self):
-   if self._param_shapes is None:
-     raise ValueError(
-         'This should not happen, workload.init_model_fn() should be called '
-         'before workload.param_shapes!')
-   return self._param_shapes

@property
def model_params_types(self):
if self._param_shapes is None:
@@ -57,19 +31,14 @@ def model_params_types(self):
self._param_shapes.unfreeze())
return self._param_types

- # Return whether or not a key in spec.ParameterContainer is the output layer
- # parameters.
- def is_output_params(self, param_key: spec.ParameterKey) -> bool:
-   pass

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
rng, params_rng, dropout_rng = jax.random.split(rng, 3)
init_fn = jax.jit(functools.partial(self._model.init, train=False))
fake_batch = jraph.GraphsTuple(
n_node=jnp.asarray([1]),
n_edge=jnp.asarray([1]),
-       nodes=jnp.ones((1, 3)),
-       edges=jnp.ones((1, 7)),
+       nodes=jnp.ones((1, 9)),
+       edges=jnp.ones((1, 3)),
globals=jnp.zeros((1, self._num_outputs)),
senders=jnp.asarray([0]),
receivers=jnp.asarray([0]))
@@ -79,17 +48,6 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
params)
return jax_utils.replicate(params), None

- # Keep this separate from the loss function in order to support optimizers
- # that use the logits.
- def output_activation_fn(self,
-                          logits_batch: spec.Tensor,
-                          loss_type: spec.LossType) -> spec.Tensor:
-   pass
-
- @property
- def loss_type(self):
-   return spec.LossType.SOFTMAX_CROSS_ENTROPY

def model_fn(
self,
params: spec.ParameterContainer,
@@ -133,17 +91,6 @@ def _binary_cross_entropy_with_mask(self,
abs_logits = jnp.where(positive_logits, logits, -logits)
return relu_logits - (logits * labels) + (jnp.log(1 + jnp.exp(-abs_logits)))

- # Does NOT apply regularization, which is left to the submitter to do in
- # `update_params`.
- def loss_fn(
-     self,
-     label_batch: spec.Tensor,
-     logits_batch: spec.Tensor,
-     mask_batch: Optional[spec.Tensor]) -> spec.Tensor:  # differentiable
-   per_example_losses = self._binary_cross_entropy_with_mask(
-       labels=label_batch, logits=logits_batch, mask=mask_batch)
-   return per_example_losses

def _eval_metric(self, labels, logits, masks):
per_example_losses = self.loss_fn(labels, logits, masks)
loss = jnp.sum(jnp.where(masks, per_example_losses, 0)) / jnp.sum(masks)
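As a quick aside (a standalone sanity sketch, not code from this PR): the `relu_logits - (logits * labels) + log(1 + exp(-|logits|))` form computed by `_binary_cross_entropy_with_mask` above is the standard numerically stable rewrite of sigmoid binary cross entropy, and can be checked against the naive formula at moderate logit magnitudes:

import jax
import jax.numpy as jnp

x = jnp.array([-8.0, -1.0, 0.0, 1.0, 8.0])  # logits
y = jnp.array([0.0, 1.0, 1.0, 0.0, 1.0])    # binary labels

# Stable form used above: max(x, 0) - x*y + log(1 + exp(-|x|)).
stable = jnp.maximum(x, 0) - x * y + jnp.log1p(jnp.exp(-jnp.abs(x)))

# Naive form -[y*log(s) + (1-y)*log(1-s)]; loses precision for large |x|.
s = jax.nn.sigmoid(x)
naive = -(y * jnp.log(s) + (1 - y) * jnp.log(1 - s))

assert jnp.allclose(stable, naive, atol=1e-5)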
@@ -156,40 +103,4 @@ def _eval_metric(self, labels, logits, masks):
in_axes=(None, 0, 0, 0, None),
static_broadcasted_argnums=(0,))
def _eval_batch(self, params, batch, model_state, rng):
-   logits, _ = self.model_fn(
-       params,
-       batch,
-       model_state,
-       spec.ForwardPassMode.EVAL,
-       rng,
-       update_batch_norm=False)
-   return self._eval_metric(batch['targets'], logits, batch['weights'])

- def _eval_model_on_split(self,
-                          split: str,
-                          num_examples: int,
-                          global_batch_size: int,
-                          params: spec.ParameterContainer,
-                          model_state: spec.ModelAuxiliaryState,
-                          rng: spec.RandomState,
-                          data_dir: str) -> Dict[str, float]:
-   """Run a full evaluation of the model."""
-   data_rng, model_rng = prng.split(rng, 2)
-   if split not in self._eval_iters:
-     eval_iter = self.build_input_queue(
-         data_rng, split, data_dir, global_batch_size=global_batch_size)
-     # Note that this stores the entire val dataset in memory.
-     self._eval_iters[split] = itertools.cycle(eval_iter)
-
-   total_metrics = None
-   num_eval_steps = int(math.ceil(float(num_examples) / global_batch_size))
-   # Loop over graph batches in eval dataset.
-   for _ in range(num_eval_steps):
-     batch = next(self._eval_iters[split])
-     batch_metrics = self._eval_batch(params, batch, model_state, model_rng)
-     total_metrics = (
-         batch_metrics
-         if total_metrics is None else total_metrics.merge(batch_metrics))
-   if total_metrics is None:
-     return {}
-   return {k: float(v) for k, v in total_metrics.reduce().compute().items()}
+   return super()._eval_batch(params, batch, model_state, rng)
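The net effect of this refactor is a template-method split: the shared eval loop, loss, and bookkeeping move into `BaseOgbgWorkload` (not shown in this capture), while each framework overrides only the per-batch pieces. A schematic, runnable sketch of the assumed structure; the names and toy bodies are illustrative, not the actual base class:

class BaseOgbgWorkload:
  """Schematic only; the real base class is not shown in this capture."""

  def _eval_batch(self, params, batch, model_state, rng):
    # Shared eval step: each framework supplies its own model_fn/_eval_metric.
    logits = self.model_fn(params, batch)
    return self._eval_metric(batch['targets'], logits, batch['weights'])


class ToyWorkload(BaseOgbgWorkload):

  def model_fn(self, params, batch):
    return [params * x for x in batch['inputs']]  # toy forward pass

  def _eval_metric(self, labels, logits, weights):
    # Weighted count of predictions that match the labels.
    return sum(w for y, z, w in zip(labels, logits, weights) if y == z)


batch = {'inputs': [1, 2], 'targets': [2, 4], 'weights': [1, 1]}
assert ToyWorkload()._eval_batch(2, batch, None, None) == 2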