diff --git a/README.md b/README.md index 216354e73..a4536c35e 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ See instructions [here](https://github.com/NVIDIA/nvidia-docker). ### Running Docker Container (Interactive) To use the Docker container as an interactive virtual environment, you can run a container mounted to your local data and code directories and execute the `bash` program. This may be useful if you are in the process of developing a submission. -1. Run detached Docker Container. The container_id will be printed if the container is run successfully. +1. Run detached Docker Container. The `container_id` will be printed if the container is running successfully. ```bash docker run -t -d \ -v $HOME/data/:/data/ \ @@ -122,7 +122,7 @@ To use the Docker container as an interactive virtual environment, you can run a -v $HOME/algorithmic-efficiency:/algorithmic-efficiency \ --gpus all \ --ipc=host \ - + \ --keep_container_alive true ``` 2. Open a bash terminal diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 38744716b..14e3c7c6c 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -28,8 +28,15 @@ def shard_and_maybe_pad_np( inputs = batch['inputs'] current_batch_size = inputs[0].shape[0] if isinstance( inputs, tuple) else inputs.shape[0] + if global_batch_size is not None: + assert global_batch_size >= current_batch_size, \ + 'global_batch_size must be larger than or equal to current_batch_size.' + # Always pad to global_batch_size if it is provided. + pad_to_global_batch_size = global_batch_size > current_batch_size + else: + pad_to_global_batch_size = False remainder_size = current_batch_size % local_device_count - if remainder_size != 0: + if remainder_size != 0 or pad_to_global_batch_size: if global_batch_size is not None: pad_size = global_batch_size - current_batch_size else: @@ -50,8 +57,8 @@ def _prepare(x): x = x._numpy() # pylint: disable=protected-access # Pad if remainder_size != 0 (should only be possible during evaluation). - if remainder_size != 0: - x = pad(x, pad_size, 'jax', padding_value=padding_value) + if remainder_size != 0 or pad_to_global_batch_size: + x = pad(x, pad_size, padding_value=padding_value) # Reshape (global_batch_size, ...) to # (local_device_count, per_device_batch_size, ...). 
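The net effect of the two hunks above: when `global_batch_size` is provided, the final (possibly short) evaluation batch is padded up to it before being reshaped onto the local devices, and `pad` itself is now NumPy-only. A minimal sketch of that behavior (the helper name `pad_to` and the toy sizes are illustrative, not part of the library):

```python
import numpy as np

def pad_to(x, target_size, padding_value=0):
  # Mirrors the simplified pad(): grow axis 0 of a NumPy array to target_size.
  pad_size = target_size - x.shape[0]
  if x.ndim > 1:
    pad_size = (pad_size, *x.shape[1:])
  padding = np.full(pad_size, padding_value, dtype=x.dtype)
  return np.concatenate((x, padding), axis=0)

# Toy final eval batch: 6 examples of 3 features, global batch size 8, 4 devices.
inputs = np.arange(18, dtype=np.float32).reshape(6, 3)
padded = pad_to(inputs, 8)          # (8, 3)
sharded = padded.reshape(4, -1, 3)  # (4, 2, 3) = (local_device_count, per_device_batch_size, ...)
print(padded.shape, sharded.shape)
```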
@@ -61,21 +68,13 @@ def _prepare(x): return jax.tree_map(_prepare, batch) -def pad(tensor: spec.Tensor, +def pad(tensor: np.ndarray, pad_size: int, - framework: str, - padding_value: int = 0) -> spec.Tensor: - if len(tensor) > 1: + padding_value: int = 0) -> np.ndarray: + if tensor.ndim > 1: pad_size = (pad_size, *tensor.shape[1:]) - if framework == 'pytorch': - padding = torch.full( - pad_size, padding_value, dtype=tensor.dtype, device=tensor.device) - padded_tensor = torch.cat((tensor, padding), dim=0) - elif framework == 'jax': - padding = np.full(pad_size, padding_value, dtype=tensor.dtype) - padded_tensor = np.concatenate((tensor, padding), axis=0) - else: - raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.') + padding = np.full(pad_size, padding_value, dtype=tensor.dtype) + padded_tensor = np.concatenate((tensor, padding), axis=0) return padded_tensor diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index af2e61581..2b3cf86f6 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,6 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data +def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): + meta_data = get_meta_data(workload) + meta_data.update({'rng_seed': rng_seed}) + write_json(meta_file_name, meta_data) + + class MetricLogger(object): """Used to log all measurements during training. diff --git a/algorithmic_efficiency/param_utils.py b/algorithmic_efficiency/param_utils.py index 00c50ee4f..b430366b1 100644 --- a/algorithmic_efficiency/param_utils.py +++ b/algorithmic_efficiency/param_utils.py @@ -41,6 +41,10 @@ def pytorch_param_types( elif 'attn' in name or 'attention' in name: if 'bias' in name: param_types[name] = spec.ParameterType.ATTENTION_BIAS + elif 'in_proj' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'kv_proj' in name: + param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: param_types[name] = spec.ParameterType.ATTENTION_K elif 'q_proj' in name or 'query' in name: @@ -51,8 +55,6 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_OUT elif 'scale' in name: param_types[name] = spec.ParameterType.WEIGHT - elif 'in_proj_weight' in name: - param_types[name] = spec.ParameterType.ATTENTION_QKV else: raise ValueError(f'Unrecognized attention parameter: {name}.') elif 'bias' in name: diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index 0a1c1be79..fa2a1bee2 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -11,6 +11,13 @@ from typing import Dict, Generator, List, Optional, Tuple import numpy as np +import torch + + +def _get_monotonic_time() -> float: + if torch.cuda.is_available() and torch.cuda.is_initialized(): + torch.cuda.synchronize() + return time.monotonic() class Profiler: @@ -20,7 +27,7 @@ def __init__(self, local_rank: Optional[int] = None) -> None: self.current_actions: Dict[str, float] = {} self.recorded_durations = defaultdict(list) - self.start_time = time.monotonic() + self.start_time = _get_monotonic_time() def set_local_rank(self, local_rank: int) -> None: self._local_rank = local_rank @@ -35,12 +42,12 @@ def start(self, action_name: str) -> None: if action_name in self.current_actions: raise ValueError( f'Attempted to start {action_name} which has already started.') - self.current_actions[action_name] = time.monotonic() + 
self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: if self.local_rank != 0: pass - end_time = time.monotonic() + end_time = _get_monotonic_time() if action_name not in self.current_actions: raise ValueError(f'Attempting to stop recording an action ' f'({action_name}) which was never started.') @@ -59,7 +66,7 @@ def profile(self, action_name: str) -> Generator: def _make_report( self ) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]: - total_duration = time.monotonic() - self.start_time + total_duration = _get_monotonic_time() - self.start_time report = [(str(a), float(np.mean(d)), float(np.std(d)), diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 570b7c55b..285983957 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -39,9 +39,10 @@ class ParameterType(enum.Enum): ATTENTION_V = 10 ATTENTION_OUT = 11 ATTENTION_QKV = 12 # This is used for implementations that fuse QKV together. - # We need to split this out because otherwise fused QKV models will have a - # different number of biases. - ATTENTION_BIAS = 13 + ATTENTION_KV = 13 # This is used for implementations that fuse KV together. + # We sometimes need to split this out because otherwise fused models will have + # a different number of biases. + ATTENTION_BIAS = 14 # Of course, Tensor knows its shape and dtype. diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 993d82c9d..55b68fb2f 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -233,7 +233,7 @@ def _eval_batch(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] - return summed_loss + return summed_loss.to(dtype=torch.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 801716de7..ef971bb75 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -63,11 +63,11 @@ def num_eval_train_examples(self) -> int: @property def num_validation_examples(self) -> int: - return 89_000_000 + return 83_274_637 @property def num_test_examples(self) -> int: - return 89_274_637 + return 95_000_000 @property def train_mean(self): diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 0ce943b3b..b787785a1 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -1,7 +1,6 @@ import copy import math from typing import Any, Callable, Dict, Optional, Tuple, Union -import warnings import torch from torch import nn @@ -11,36 +10,31 @@ from torch.nn.init import xavier_uniform_ -def make_causal_mask(x: Tensor, - device: str = 'cuda:0', - dtype: torch.dtype = torch.float32) -> Tensor: +def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: """Make a causal mask for self-attention. Args: x: input array of shape `[batch..., len]` device: device to store the idxs - dtype: mask return dtype Returns: A `[batch..., len, len]` shaped causal attention mask. 
""" idxs = torch.broadcast_to( torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) - return torch.greater_equal(idxs.unsqueeze(-1), - idxs.unsqueeze(-2)).to(dtype=dtype) + return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) def make_src_mask(src, inputs_segmentation, nhead): """Utility for creating src mask and adjust it for PyTorch Transformer API.""" - src_mask = torch.mul((src > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)).to(dtype=torch.float32) + src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: src_mask = torch.logical_and( src_mask, torch.eq( inputs_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32)) + inputs_segmentation.unsqueeze(-2))) # Flip values and ensure numerical stability. src_mask = torch.repeat_interleave( torch.logical_not(src_mask), repeats=nhead, dim=0) @@ -59,27 +53,25 @@ def make_tgt_and_memory_mask(tgt, Transformer API.""" if not decode: tgt_mask = torch.logical_and( - torch.mul((tgt > 0).unsqueeze(-1), - (tgt > 0).unsqueeze(-2)).to(dtype=torch.float32), + torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), make_causal_mask(tgt, device=tgt.device)) - memory_mask = torch.mul((tgt > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)).to(dtype=torch.float32) + memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) else: tgt_mask = None memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)).to(dtype=torch.float32) + (src > 0).unsqueeze(-2)) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: tgt_mask = torch.logical_and( tgt_mask, torch.eq( targets_segmentation.unsqueeze(-1), - targets_segmentation.unsqueeze(-2)).to(dtype=torch.float32)) + targets_segmentation.unsqueeze(-2))) memory_mask = torch.logical_and( memory_mask, torch.eq( targets_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32)) + inputs_segmentation.unsqueeze(-2))) # Flip values and ensure numerical stability. memory_mask = torch.repeat_interleave( torch.logical_not(memory_mask), repeats=nhead, dim=0) @@ -417,8 +409,7 @@ def forward( # TransformerEncoderLayer and TransformerDecoderLayer are taken from: # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py -# Only difference is using custom MultiheadAttention modules without bias and -# '_qkv_same_embed_dim' always set to 'False'. +# Main difference is the use of custom MultiheadAttention modules. class TransformerEncoderLayer(nn.Module): r"""TransformerEncoderLayer is made up of self-attn and feedforward network. This standard encoder layer is based on the paper "Attention Is All You Need". @@ -437,22 +428,15 @@ class TransformerEncoderLayer(nn.Module): string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components (default=1e-6). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``True`` (batch, seq, feature). norm_first: if ``True``, layer norm is done prior to attention and feedforward operations, respectivaly. Otherwise it's done after. Default: ``True``. 
Examples:: >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - Alternatively, when ``batch_first`` is ``True``: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, - batch_first=True) >>> src = torch.rand(32, 10, 512) >>> out = encoder_layer(src) """ - __constants__ = ['batch_first', 'norm_first'] + __constants__ = ['norm_first'] def __init__(self, d_model: int = 1024, @@ -462,7 +446,6 @@ def __init__(self, attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, - batch_first: bool = True, norm_first: bool = True, device=None, dtype=None) -> None: @@ -471,8 +454,8 @@ def __init__(self, self.self_attn = MultiheadAttention( d_model, nhead, + self_attn=True, dropout_rate=attention_dropout_rate, - batch_first=batch_first, bias=False, **factory_kwargs) @@ -519,7 +502,7 @@ def forward(self, 'Only bool and floating types of key_padding_mask are supported') x = src if self.norm_first: - x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._sa_block(self.norm1(x), src_mask) x = x + self._ff_block(self.norm2(x)) else: x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) @@ -528,17 +511,8 @@ def forward(self, return x # Self-attention block: - def _sa_block(self, - x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor]) -> Tensor: - x = self.self_attn( - x, - x, - x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False)[0] + def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask) return self.dropout1(x) # Feed forward block: @@ -547,7 +521,8 @@ def _ff_block(self, x: Tensor) -> Tensor: return self.dropout2(x) -# Modified to use cache for autoregressive decoding. +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. class TransformerDecoder(nn.Module): r"""TransformerDecoder is a stack of N decoder layers Args: @@ -630,7 +605,8 @@ def forward(self, return output -# Modified to use cache for autoregressive decoding. +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. class TransformerDecoderLayer(nn.Module): r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. @@ -650,24 +626,16 @@ class TransformerDecoderLayer(nn.Module): string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components (default=1e-6). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``True`` (batch, seq, feature). norm_first: if ``True``, layer norm is done prior to self attention, multihead attention and feedforward operations, respectivaly. Otherwise it's done after. Default: ``True``. 
Examples:: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - Alternatively, when ``batch_first`` is ``True``: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, - batch_first=True) >>> memory = torch.rand(32, 10, 512) >>> tgt = torch.rand(32, 20, 512) >>> out = decoder_layer(tgt, memory) """ - __constants__ = ['batch_first', 'norm_first'] + __constants__ = ['norm_first'] def __init__(self, d_model: int = 1024, @@ -677,7 +645,6 @@ def __init__(self, attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-6, - batch_first: bool = True, norm_first: bool = True, device=None, dtype=None) -> None: @@ -686,15 +653,15 @@ def __init__(self, self.self_attn = MultiheadAttention( d_model, nhead, + self_attn=True, dropout_rate=attention_dropout_rate, - batch_first=batch_first, bias=False, **factory_kwargs) self.multihead_attn = MultiheadAttention( d_model, nhead, + self_attn=False, dropout_rate=attention_dropout_rate, - batch_first=batch_first, bias=False, **factory_kwargs) @@ -746,7 +713,7 @@ def forward( # pylint: disable=arguments-renamed cache=cache, index=index) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask, None) + x = x + self._mha_block(self.norm2(x), memory, memory_mask) x = x + self._ff_block(self.norm3(x)) else: sa_out, cache = self._sa_block( @@ -757,7 +724,7 @@ def forward( # pylint: disable=arguments-renamed cache=cache, index=index) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask, None)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask)) x = self.norm3(x + self._ff_block(x)) return x, cache @@ -771,12 +738,9 @@ def _sa_block( # pylint: disable=arguments-renamed max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None) -> Any: - x, _, cache = self.self_attn( - x, - x, + x, cache = self.self_attn( x, attn_mask=attn_mask, - need_weights=False, decode=decode, max_len=max_len, cache=cache, @@ -784,18 +748,9 @@ def _sa_block( # pylint: disable=arguments-renamed return self.dropout1(x), cache # Multihead attention block: - def _mha_block(self, - x: Tensor, - mem: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor]) -> Tensor: - x = self.multihead_attn( - x, - mem, - mem, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False)[0] + def _mha_block(self, x: Tensor, mem: Tensor, + attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) return self.dropout2(x) # Feed forward block. @@ -804,12 +759,10 @@ def _ff_block(self, x: Tensor) -> Tensor: return self.dropout3(x) -# Only difference to standard PyTorch class is that 'self._qkv_same_embed_dim' -# is always set to 'False' and the use of a cache registered as a buffer for -# autoregressive decoding. -class MultiheadAttention(nn.MultiheadAttention): +class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information - from different representation subspaces. + from different representation subspaces. Supports self-attention and + encoder-decoder attention. See `Attention Is All You Need `_. .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O @@ -819,117 +772,75 @@ class MultiheadAttention(nn.MultiheadAttention): num_heads: Number of parallel attention heads. 
Note that ``embed_dim`` will be split across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + self_attn: Whether self attention or encoder-decoder attention is used. + Default: ``True``. dropout_rate: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. - Default: ``True``. - add_bias_kv: If specified, adds bias to the key and value sequences at - dim=0. Default: ``False``. - add_zero_attn: If specified, adds a new batch of zeros to the key and value - sequences at dim=1. Default: ``False``. - kdim: Total number of features for keys. Default: ``None`` - (uses ``kdim=embed_dim``). - vdim: Total number of features for values. Default: ``None`` - (uses ``vdim=embed_dim``). - batch_first: If ``True``, then the input and output tensors are provided - as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + Default: ``False``. + device: The device of the module. + dtype: The dtype of the module. Examples:: >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + >>> attn_output, cache = multihead_attn(x) """ def __init__(self, - embed_dim, - num_heads, - dropout_rate=0., - bias=True, - add_bias_kv=False, - add_zero_attn=False, - kdim=None, - vdim=None, - batch_first=True, - device=None, - dtype=None) -> None: - super().__init__( - embed_dim, - num_heads, - dropout=dropout_rate, - bias=bias, - add_bias_kv=add_bias_kv, - add_zero_attn=add_zero_attn, - kdim=kdim, - vdim=vdim, - batch_first=batch_first, - device=device, - dtype=dtype) - # This is set to 'True' for kdim == vdim == embed_dim in the standard - # PyTorch class. - self._qkv_same_embed_dim = False + embed_dim: int, + num_heads: int, + self_attn: bool = True, + dropout_rate: float = 0., + bias: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.self_attn = self_attn + self.dropout = dropout_rate + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, \ + 'embed_dim must be divisible by num_heads.' factory_kwargs = {'device': device, 'dtype': dtype} - self.q_proj_weight = nn.Parameter( - torch.empty((embed_dim, embed_dim), **factory_kwargs)) - self.k_proj_weight = nn.Parameter( - torch.empty((embed_dim, self.kdim), **factory_kwargs)) - self.v_proj_weight = nn.Parameter( - torch.empty((embed_dim, self.vdim), **factory_kwargs)) - self.register_parameter('in_proj_weight', None) + if self_attn: + # Self-attention. + self.in_proj = nn.Linear( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + else: + # Encoder-decoder attention. 
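# Queries are projected from the decoder input x, while keys and values are
# projected from the encoder output `mem` (see forward() below), which is why
# the key and value projections can be fused into a single kv_proj layer.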
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.kv_proj = nn.Linear( + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self._reset_parameters() def _reset_parameters(self): - if self._qkv_same_embed_dim: - xavier_uniform_(self.in_proj_weight) - else: - xavier_uniform_(self.q_proj_weight) - xavier_uniform_(self.k_proj_weight) - xavier_uniform_(self.v_proj_weight) - - if self.in_proj_bias is not None: - normal_(self.in_proj_bias, std=1e-6) - normal_(self.out_proj.bias, std=1e-6) - if self.bias_k is not None: - normal_(self.bias_k, std=1e-6) - if self.bias_v is not None: - normal_(self.bias_v, std=1e-6) + """Initiate parameters in the MultiheadAttention module.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + xavier_uniform_(module.weight) + if module.bias is not None: + normal_(module.bias, std=1e-6) def forward(self, - query: Tensor, - key: Tensor, - value: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, + x: Tensor, + mem: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, - average_attn_weights: bool = True, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None) -> Any: r""" Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, - :math:`(L, N, E_q)` when ``batch_first=False`` or :math:`(N, L, E_q)` - when ``batch_first=True``, where :math:`L` is the target sequence - length, :math:`N` is the batch size, and :math:`E_q` is the query - embedding dimension ``embed_dim``. - Queries are compared against key-value pairs to produce the output. - See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, - :math:`(S, N, E_k)` when ``batch_first=False`` or :math:`(N, S, E_k)` - when ``batch_first=True``, where :math:`S` is the source sequence - length, :math:`N` is the batch size, and :math:`E_k` is the key - embedding dimension ``kdim``. - See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, - :math:`(S, N, E_v)` when ``batch_first=False`` or :math:`(N, S, E_v)` - when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the - value embedding dimension ``vdim``. - See "Attention Is All You Need" for more details. - key_padding_mask: Dummy argument to make MultiheadAttention compatible - with standard PyTorch TransformerEncoder implementation. - need_weights: If specified, returns ``attn_output_weights`` in addition - to ``attn_outputs``.Default: ``True``. + x: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for self + attention mechanism. See "Attention Is All You Need" for more details. + mem: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for + encoder-decoder attention. See "Attention Is All You Need" for more + details. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the @@ -942,338 +853,111 @@ def forward(self, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. 
- average_attn_weights: If true, indicates that the returned - ``attn_weights`` should be averaged across heads. Otherwise, - ``attn_weights`` are provided separately per head. Note that this - flag only has an effect when ``need_weights=True``. Default: - ``True`` (i.e. average weights across heads) decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + cache: cache dictionary for autoregressive decoding. + index: index of the current decoding step, necessary for decoding cache. Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input - is unbatched, :math:`(L, N, E)` when ``batch_first=False`` or - :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch - size, and :math:`E` is the embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. - If ``average_attn_weights=True``, returns attention weights averaged - across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the - target sequence length, and :math:`S` is the source sequence length. - If ``average_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched - or :math:`(N, \text{num\_heads}, L, S)`. - .. note:: - `batch_first` argument is ignored for unbatched inputs. + - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where + :math:`L` is the target sequence length, :math:`N` is the batch size, + and :math:`E` is the embedding dimension ``embed_dim``. + - **cache** - For autoregressive decoding. """ - del key_padding_mask - is_batched = query.dim() == 3 - if self.batch_first and is_batched: - # make sure that the transpose op does not affect the "is" property - if key is value: - if query is key: - query = key = value = query.transpose(1, 0) - else: - query, key = [x.transpose(1, 0) for x in (query, key)] - value = key - else: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + # Shape: (batch size, sequence length, embedding dimensionality) + bsz, seq_len, embed_dim = x.size() + # In projection. + if self.self_attn: + q, k, v = self.in_proj(x).split(self.embed_dim, dim=2) + else: + q = self.q_proj(x) + k, v = self.kv_proj(mem).split(self.embed_dim, dim=2) + # This is 1 (!= seq_len) during autoreregressive decoding. + tgt_len = q.size(1) + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
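# The per-layer cache created below holds 'cached_key' and 'cached_value'
# buffers of shape (bsz, max_len, embed_dim) plus an integer 'cache_index';
# each decode step writes a single position into the buffers and masks out all
# positions that have not been generated yet.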
name = f'decoder.layers.{index}.self_attn' loc_cache = cache[name] if decode and name in cache else None - - attn_output, attn_output_weights, loc_cache = multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_bias, self.bias_k, self.bias_v, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, need_weights=need_weights, attn_mask=attn_mask, - q_proj_weight=self.q_proj_weight, - k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, - average_attn_weights=average_attn_weights, - decode=decode, cache=loc_cache, max_len=max_len) + if decode: + if loc_cache is None: + loc_cache = { + 'cached_key': + torch.zeros((bsz, max_len, embed_dim), + dtype=k.dtype, + device=k.device), + 'cached_value': + torch.zeros((bsz, max_len, embed_dim), + dtype=v.dtype, + device=v.device), + 'cache_index': + torch.tensor(0, dtype=torch.long, device=k.device), + } + cached_key = loc_cache['cached_key'] + cached_value = loc_cache['cached_value'] + cache_index = loc_cache['cache_index'] + # Shape check of cached keys against query input. + expected_shape = (bsz, 1, embed_dim) + if expected_shape != x.shape: + raise ValueError('Autoregressive cache shape error, expected query ' + f'shape {expected_shape} instead got {x.shape}.') + # Update key, value caches with our new 1d spatial slices. + cached_key[:, cache_index:cache_index + 1, :] = k + cached_value[:, cache_index:cache_index + 1, :] = v + k = cached_key + v = cached_value + cache_index += 1 + # Causal mask for cached decoder self-attention: + # our single query position should only attend to those key + # positions that have already been generated and cached, + # not the remaining zero elements. + if attn_mask is not None: + raise ValueError('Attention mask has to be None for decode == True.') + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) + + # Update sequence length to account for complete sequence. + seq_len = k.size(1) + + # Rearrange q, k, v for multihead attention. + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Check dtype and shape of attention mask. + if not decode and attn_mask is not None: + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ + f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + # Ensure attn_mask's dim is 3. + if attn_mask.dim() == 3: + correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' + f'but should be {correct_3d_size}.') + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported") + # Reshape attention mask to be consistent with q, k, v. + attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len) + + # Convert attention mask to float. + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, -1e10) + attn_mask = new_attn_mask + + # Adjust dropout_rate probability. + dropout_rate = self.dropout if self.training else 0.0 + + # Calculate attention. + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask, dropout_rate) + # Rearrange for output projection. 
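# attn_output is (bsz, num_heads, tgt_len, head_dim) here; the transpose/view
# below folds the heads back into (bsz, tgt_len, embed_dim) for out_proj.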
+ attn_output = attn_output.transpose(1, 2).contiguous().view( + bsz, tgt_len, embed_dim) + # Output projection. + attn_output = self.out_proj(attn_output) if decode: cache[name] = loc_cache - if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights, cache - else: - return attn_output, attn_output_weights, cache - - -def _in_projection( - q: Tensor, - k: Tensor, - v: Tensor, - w_q: Tensor, - w_k: Tensor, - w_v: Tensor, - b_q: Optional[Tensor] = None, - b_k: Optional[Tensor] = None, - b_v: Optional[Tensor] = None, -) -> Tuple[Tensor, Tensor, Tensor]: - r"""Performs the in-projection step of the attention operation. This is simply - a triple of linear projections, with shape constraints on the weights which - ensure embedding dimension uniformity in the projected outputs. - Output is a triple containing projection tensors for query, key and value. - """ - eq, ek = q.size(-1), k.size(-1) - assert w_q.shape == (eq, eq), \ - f'Expecting query weights shape of {(eq, eq)}, but got {w_q.shape}' - assert w_k.shape == (eq, ek), \ - f'Expecting key weights shape of {(eq, ek)}, but got {w_k.shape}' - assert w_v.shape == (eq, ek), \ - f'Expecting value weights shape of {(eq, ek)}, but got {w_v.shape}' - assert b_q is None or b_q.shape == (eq,), \ - f'Expecting query bias shape of {(eq,)}, but got {b_q.shape}' - assert b_k is None or b_k.shape == (eq,), \ - f'Expecting key bias shape of {(eq,)}, but got {b_k.shape}' - assert b_v is None or b_v.shape == (eq,), \ - f'Expecting value bias shape of {(eq,)}, but got {b_v.shape}' - return torch.nn.functional.linear(q, w_q, b_q), \ - torch.nn.functional.linear(k, w_k, b_k), \ - torch.nn.functional.linear(v, w_v, b_v) - - -# Modified to create cache for autoregressive decoding. -def multi_head_attention_forward(query: Tensor, - key: Tensor, - value: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_bias: Optional[Tensor], - bias_k: Optional[Tensor], - bias_v: Optional[Tensor], - dropout_rate: float, - out_proj_weight: Tensor, - out_proj_bias: Optional[Tensor], - training: bool = True, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - q_proj_weight: Optional[Tensor] = None, - k_proj_weight: Optional[Tensor] = None, - v_proj_weight: Optional[Tensor] = None, - average_attn_weights: bool = True, - decode: bool = False, - cache: Optional[dict] = None, - max_len: Optional[int] = None) -> Any: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - See "Attention Is All You Need" for more details. - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_bias: input projection bias. - bias_k, bias_v: bias of the key and value sequences to be added at dim=0. - dropout_rate: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout_rate if is ``True``. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. - A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the - entries of each batch. - q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: - input projection weight and bias. - average_attn_weights: If true, indicates that the returned ``attn_weights`` - should be averaged across heads. - Otherwise, ``attn_weights`` are provided separately per head. - Note that this flag only has an effect when ``need_weights=True.``. 
- Default: True - decode: wether to use cache for autoregressive decoding or not. - cache: dict which contains cache for decoding for the current - MulitheadAttention module. - max_len: maximum sequence length, necessary for decoding cache. - Shape: - Inputs: - - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence - length, N is the batch size, E is the embedding dimension. - - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence - length, N is the batch size, E is the embedding dimension. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, - S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)` - where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is - allowed to attend the unmasked positions. If a ByteTensor is provided, - the non-zero positions are not allowed to attend while the zero positions - will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. - If a FloatTensor is provided, it will be added to the attention weight. - Outputs: - - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target - sequence length, N is the batch size, E is the embedding dimension. - - attn_output_weights: Only returned when ``need_weights=True``. - If ``average_attn_weights=True``, returns - attention weights averaged across heads of shape :math:`(L, S)` when input - is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source - sequence length. If ``average_weights=False``, returns attention weights - per head of shape :math:`(num_heads, L, S)` when input is unbatched or - :math:`(N, num_heads, L, S)`. - """ - # Set up shape variables. - tgt_len, bsz, embed_dim = query.shape - src_len, _, _ = key.shape - assert embed_dim == embed_dim_to_check, \ - f'was expecting dimension of {embed_dim_to_check}, but got {embed_dim}' - if isinstance(embed_dim, torch.Tensor): - # `embed_dim` can be a tensor when JIT tracing. - head_dim = embed_dim.div(num_heads, rounding_mode='trunc') - else: - head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, \ - f'embed_dim {embed_dim} not divisible by num_heads {num_heads}' - # Allow MHA to have different embedding dimensions when separate projection - # weights are used. - assert key.shape[:2] == value.shape[:2], \ - (f"key's sequence and batch dims {key.shape[:2]} do not match value's " - f'{value.shape[:2]}') - - # Compute in-projection. - assert q_proj_weight is not None, \ - 'use_separate_proj_weight is True but q_proj_weight is None' - assert k_proj_weight is not None, \ - 'use_separate_proj_weight is True but k_proj_weight is None' - assert v_proj_weight is not None, \ - 'use_separate_proj_weight is True but v_proj_weight is None' - if in_proj_bias is None: - b_q = b_k = b_v = None - else: - b_q, b_k, b_v = in_proj_bias.chunk(3) - q, k, v = _in_projection( - query, key, value, q_proj_weight, k_proj_weight, - v_proj_weight, b_q, b_k, b_v) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. 
- if decode: - if cache is None: - cache = { - 'cached_key': - torch.zeros((bsz, max_len, embed_dim), - dtype=k.dtype, - device=k.device), - 'cached_value': - torch.zeros((bsz, max_len, embed_dim), - dtype=v.dtype, - device=v.device), - 'cache_index': - torch.tensor(0, dtype=torch.long, device=k.device), - } - cached_key = cache['cached_key'] - cached_value = cache['cached_value'] - cache_index = cache['cache_index'] - batch_size, max_length, num_features = cached_key.shape - assert batch_size == bsz, f'{batch_size} != {bsz}' - assert max_length == max_len, f'{max_length} != {max_len}' - assert num_features == embed_dim, f'{num_features} != {embed_dim}' - # Shape check of cached keys against query input. - expected_shape = (1, batch_size, num_features) - if expected_shape != query.shape: - raise ValueError('Autoregressive cache shape error, expected query shape ' - f'{expected_shape} instead got {query.shape}.') - # Update key, value caches with our new 1d spatial slices. - cached_key[:, cache_index:cache_index + 1, :] = k.transpose(dim0=0, dim1=1) - cached_value[:, cache_index:cache_index + 1, :] = v.transpose( - dim0=0, dim1=1) - k = cached_key.transpose(dim0=0, dim1=1) - v = cached_value.transpose(dim0=0, dim1=1) - cache_index += 1 - # Causal mask for cached decoder self-attention: - # our single query position should only attend to those key - # positions that have already been generated and cached, - # not the remaining zero elements. - if attn_mask is not None: - raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_length, device=k.device) >= - cache_index).reshape(1, max_length) - - # Prepare attention mask. - if not decode and attn_mask is not None: - if attn_mask.dtype == torch.uint8: - warnings.warn( - 'Byte tensor for attn_mask in nn.MultiheadAttention is deprecated.' - 'Use bool tensor instead.') - attn_mask = attn_mask.to(torch.bool) - else: - assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ - f'float, byte, and bool types are supported, not {attn_mask.dtype}' - # Ensure attn_mask's dim is 3. - if attn_mask.dim() == 2: - correct_2d_size = (tgt_len, src_len) - if attn_mask.shape != correct_2d_size: - raise RuntimeError( - f'The shape of the 2D attn_mask is {attn_mask.shape}, ' - f'but should be {correct_2d_size}.') - attn_mask = attn_mask.unsqueeze(0) - elif attn_mask.dim() == 3: - correct_3d_size = (bsz * num_heads, tgt_len, src_len) - if attn_mask.shape != correct_3d_size: - raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' - f'should be {correct_3d_size}.') - else: - raise RuntimeError( - f"attn_mask's dimension {attn_mask.dim()} is not supported") - - # Add bias along batch dimension (currently second). - if bias_k is not None and bias_v is not None: - k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) - v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) - if attn_mask is not None: - attn_mask = F.pad(attn_mask, (0, 1)) - else: - assert bias_k is None - assert bias_v is None - - # Reshape q, k, v for multihead attention and make em batch first. - q = \ - q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) - k = \ - k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) - v = \ - v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) - - # Update source sequence length after adjustments. - src_len = k.size(1) - - # Convert mask to float. 
- if attn_mask is not None and attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, -1e10) - attn_mask = new_attn_mask - - # Adjust dropout_rate probability. - if not training: - dropout_rate = 0.0 - - # Calculate attention and out projection. - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) - attn_output = attn_output.transpose(0, 1).contiguous().view( - tgt_len * bsz, embed_dim) - attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) - attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) - - if need_weights: - q_scaled = q / math.sqrt(q.shape[-1]) - - if attn_mask is not None: - attn_output_weights = torch.baddbmm(attn_mask, - q_scaled, - k.transpose(-2, -1)) - else: - attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) - - # Optionally average attention weights over heads. - attn_output_weights = attn_output_weights.view(bsz, - num_heads, - tgt_len, - src_len) - if average_attn_weights: - attn_output_weights = attn_output_weights.sum(dim=1) / num_heads - return attn_output, attn_output_weights, cache - else: - return attn_output, None, cache + return attn_output, cache diff --git a/datasets/README.md b/datasets/README.md index 5afe257fe..5ff0e18a7 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -1,28 +1,125 @@ # Dataset Setup -Use `dataset_setup.py` to download datasets, for example: -``` +TL;DR: +Use `dataset_setup.py` to download datasets. +Usage: +```bash python3 datasets/dataset_setup.py \ --data_dir=~/data \ - --ogbg + -- + -- +``` +The complete benchmark uses 6 datasets: +- OGBG +- WMT +- FastMRI +- ImageNet +- Criteo 1TB +- Librispeech + + +Some dataset setups will require you to sign a third-party agreement with the dataset owners in order to get the download URLs. + +# Per dataset instructions +## Environment + +### Set data directory (Docker container) +If you are running the `dataset_setup.py` script from a Docker container, please +make sure the data directory is mounted to a directory on your host with the +`-v` flag. If you are following instructions from the README you will have used +the `-v $HOME/data:/data` flag in the `docker run` command. This will mount +the `$HOME/data` directory to the `/data` directory in the container. +In this case set --data_dir to `/data`. +```bash +DATA_DIR='/data' +``` +### Set data directory (on host) +Alternatively, if you are running the data download script directly on your host, feel free +to choose whatever directory you find suitable; note that the subsequent submission instructions +assume the data is stored in `~/data`. +```bash +DATA_DIR='~/data' +``` +#### Start tmux session (Recommended) +If you are running `dataset_setup.py` directly on the host, it is recommended to run +it in a tmux session, because some of the data downloads may +take several hours. To avoid your setup being interrupted, start a tmux session: +```bash +tmux new -s data_setup +``` + + +## Datasets + +### OGBG +From `algorithmic-efficiency` run: +```bash +python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR/ogbg \ +--ogbg ``` -This will require the same pip dependencies as `submission_runner.py`.
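As an optional sanity check (not part of the setup script), the prepared OGBG data can be inspected with TFDS from Python. This assumes `tensorflow_datasets` is installed (it is already required by `dataset_setup.py`) and that `$DATA_DIR/ogbg` from the command above resolves to `~/data/ogbg`:

```python
import os
import tensorflow_datasets as tfds

# Same builder name and version that dataset_setup.py prepares.
data_dir = os.path.expanduser('~/data/ogbg')
builder = tfds.builder('ogbg_molpcba:0.1.3', data_dir=data_dir)
print(builder.info.splits)  # Expect train/validation/test splits with example counts.
```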
+### WMT +From `algorithmic-efficiency` run: +```bash +python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--wmt +``` -Some datasets require signing a form before downloading: -FastMRI: -Fill out form on https://fastmri.med.nyu.edu/ and run this script with the -links that are emailed to you for "knee_singlecoil_train" and -"knee_singlecoil_val". +## FastMRI +Fill out the form on https://fastmri.med.nyu.edu/. After filling out the form +you should get an email containing the URLs for "knee_singlecoil_train", +"knee_singlecoil_val" and "knee_singlecoil_test". -ImageNet: -Register on https://image-net.org/ and run this script with the links to the -ILSVRC2012 train and validation images. +```bash +python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--fastmri \ +--fastmri_knee_singlecoil_train_url '' \ +--fastmri_knee_singlecoil_val_url '' \ +--fastmri_knee_singlecoil_test_url '' +``` -Note for tfds ImageNet, you may have to increase the max number of files allowed -open at once using `ulimit -n 8192`. +## ImageNet +Register on https://image-net.org/ and follow the directions to obtain the +URLs for the ILSVRC2012 train and validation images. -Note that in order to avoid potential accidental deletion, this script does NOT +ImageNet dataset processing is resource-intensive. To avoid potential +ResourceExhausted errors, increase the maximum number of open file descriptors: +```bash +ulimit -n 8192 +``` + +The ImageNet data pipeline differs between the PyTorch and JAX workloads. +Therefore, you will have to specify the framework (pytorch or jax) via the `--framework` flag. + +```bash +python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--imagenet \ +--temp_dir $DATA_DIR/tmp \ +--imagenet_train_url \ +--imagenet_val_url \ +--framework jax + +``` + +Note that some functions use subprocess.Popen(..., shell=True), which can be +dangerous if the user injects code into the --data_dir or --temp_dir flags. We +do some basic sanitization in main(), but submitters should not let untrusted +users run this script on their systems. + +## Criteo1tb +```bash +python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--criteo1tb +``` + +### Clean up +In order to avoid potential accidental deletion, this script does NOT delete any intermediate temporary files (such as zip archives) without a user confirmation. Deleting temp files is particularly important for Criteo 1TB, as there can be multiple copies of the dataset on disk during preprocessing if @@ -31,17 +128,21 @@ can pass --interactive_deletion=false and then all files will be downloaded to the provided --temp_dir, and the user can manually delete these after downloading has finished. -Note that some functions use subprocess.Popen(..., shell=True), which can be -dangerous if the user injects code into the --data_dir or --temp_dir flags. We -do some basic sanitization in main(), but submitters should not let untrusted -users run this script on their systems. ## Librispeech +To download, train a tokenizer, and preprocess the librispeech dataset: +```bash +python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--librispeech +``` -### Training SPM Tokenizer -This step trains a simple sentence piece tokenizer over librispeech training data. -This tokenizer is then used in later preprocessing step to tokenize transcripts.
-This command will generate `spm_model.vocab` file in `$DATA_DIR/librispeech`: +### Notes on librispeech preprocessing +#### Training SPM Tokenizer + A simple sentence piece tokenizer is trained over librispeech training + data. This tokenizer is then used in later preprocessing step to tokenize transcripts. +This command generates `spm_model.vocab` file in `$DATA_DIR/librispeech`: ```bash python3 librispeech_tokenizer.py --train --data_dir=$DATA_DIR/librispeech ``` @@ -51,9 +152,12 @@ The trained tokenizer can be loaded back to do sanity check by tokenizing + de-t librispeech_tokenizer.py --data_dir=$DATA_DIR/librispeech ``` -### Preprocessing Script +#### Preprocessing Script The preprocessing script will generate `.npy` files for audio data, `features.csv` which has paths to saved audio `.npy`, and `trans.csv` which has paths to `features.csv` and transcription data. ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + + + diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index d1636a3e5..e7f8c1d13 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -170,7 +170,7 @@ flags.DEFINE_string( 'fastmri_knee_singlecoil_test_url', None, - 'Only necessary if you want this script to `wget` the FastMRI validation ' + 'Only necessary if you want this script to `wget` the FastMRI test ' 'split. If not, you can supply the path to --data_dir in ' 'submission_runner.py.') @@ -207,13 +207,11 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion): def _download_url(url, data_dir, name=None): - - data_dir = os.path.expanduser(data_dir) if not name: file_path = os.path.join(data_dir, url.split('/')[-1]) else: file_path = os.path.join(data_dir, name) - logging.info(f'About to download to {file_path}') + logging.info(f'Downloading URL {url} to {file_path}') response = requests.get(url, stream=True, timeout=600) total_size_in_bytes = int(response.headers.get('Content-length', 0)) @@ -230,7 +228,7 @@ def _download_url(url, data_dir, name=None): break logging.info('Invalid response. Try again.') if overwrite == 'n': - logging.info('Skipping download to {}'.format(file_path)) + logging.info(f'Skipping download URL {url} to {file_path}') return with open(file_path, 'wb') as f: @@ -300,7 +298,7 @@ def download_criteo1tb(data_dir, logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}') p = subprocess.Popen(unzip_cmd, shell=True) p.communicate() - _maybe_prompt_for_deletion(all_days_zip_filepath, interactive_deletion) + _maybe_prompt_for_deletion([all_days_zip_filepath], interactive_deletion) # Unzip the individual days. processes = [] @@ -318,9 +316,9 @@ def download_criteo1tb(data_dir, _maybe_prompt_for_deletion(gz_paths, interactive_deletion) # Split into files with 5M lines each: day_1.csv -> day_1_[0-39].csv. 
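# unzipped_paths is now accumulated across all six batches, and the deletion
# prompt runs once after the outer loop, so every day_*.csv file can be offered
# for cleanup instead of only those from the last batch.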
+ unzipped_paths = [] for batch in range(6): batch_processes = [] - unzipped_paths = [] for day_offset in range(4): day = batch * 4 + day_offset unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') @@ -332,7 +330,7 @@ def download_criteo1tb(data_dir, batch_processes.append(subprocess.Popen(split_cmd, shell=True)) for p in batch_processes: p.communicate() - _maybe_prompt_for_deletion(unzipped_paths, interactive_deletion) + _maybe_prompt_for_deletion(unzipped_paths, interactive_deletion) def download_cifar(data_dir, framework): @@ -355,7 +353,7 @@ def extract_filename_from_url(url, start_str='knee', end_str='.xz'): end = url.find(end_str) if failure in (start, end): raise ValueError( - f'Unable to locate filename wrapped in {start}--{end} in {url}') + f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}') end += len(end_str) # make it inclusive return url[start:end] @@ -364,7 +362,6 @@ def download_fastmri(data_dir, fastmri_train_url, fastmri_val_url, fastmri_test_url): - data_dir = os.path.join(data_dir, 'fastmri') # Download fastmri train dataset knee_train_filename = extract_filename_from_url(fastmri_train_url) @@ -393,7 +390,7 @@ def extract(source, dest): if not os.path.exists(dest): os.path.makedirs(dest) logging.info(f'Extracting {source} to {dest}') - tar = tarfile.open(source) + tar = tarfile.open(source, 'r:xz') logging.info('Opened tar') tar.extractall(dest) @@ -430,17 +427,28 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) + imagenet_jax_data_dir = os.path.join(data_dir, 'jax') + manual_download_dir = os.path.join(imagenet_jax_data_dir, + 'downloads', + 'manual') + imagenet_train_download_filepath = os.path.join(manual_download_dir, + IMAGENET_TRAIN_TAR_FILENAME) + imagenet_val_download_filepath = os.path.join(manual_download_dir, + IMAGENET_VAL_TAR_FILENAME) + # Download imagnet train dataset - if not os.path.exists(imagenet_train_filepath): + if not os.path.exists(imagenet_train_filepath) and not os.path.exists( + imagenet_train_download_filepath): logging.info( 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)) - _download_url(url=imagenet_train_url, data_dir=data_dir).download() + _download_url(url=imagenet_train_url, data_dir=data_dir) # Download imagenet val dataset - if not os.path.exists(imagenet_val_filepath): + if not os.path.exists(imagenet_val_filepath) and not os.path.exists( + imagenet_val_download_filepath): logging.info('Downloading imagenet validation dataset from {}'.format( imagenet_val_url)) - _download_url(url=imagenet_val_url, data_dir=data_dir).download() + _download_url(url=imagenet_val_url, data_dir=data_dir) # Download imagenet test set download_imagenet_v2(data_dir) @@ -460,6 +468,7 @@ def setup_imagenet(data_dir, framework=None): def setup_imagenet_jax(data_dir): train_tar_file_path = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) + test_dir_path = os.path.join(data_dir, 'imagenet_v2') # Setup jax dataset dir imagenet_jax_data_dir = os.path.join(data_dir, 'jax') @@ -472,17 +481,20 @@ def setup_imagenet_jax(data_dir): logging.info('Checking if tar files already exists in jax/downloads/manual.') if not os.path.exists( os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): - logging.info('Copying {} to {}'.format(train_tar_file_path, - 
manual_download_dir)) + logging.info('Moving {} to {}'.format(train_tar_file_path, + manual_download_dir)) shutil.move(train_tar_file_path, manual_download_dir) if not os.path.exists( os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): - logging.info('Copying {} to {}'.format(val_tar_file_path, - manual_download_dir)) + logging.info('Moving {} to {}'.format(val_tar_file_path, + manual_download_dir)) shutil.move(val_tar_file_path, manual_download_dir) + if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): + logging.info('Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) + shutil.move(test_dir_path, + os.path.join(imagenet_jax_data_dir, 'imagenet_v2')) logging.info('Preparing imagenet data.') - resource.setrlimit(resource.RLIMIT_NOFILE, - (resource.RLIM_INFINITY, resource.RLIM_INFINITY)) ds_builder = tfds.builder( 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir)) ds_builder.download_and_prepare() @@ -492,6 +504,7 @@ def setup_imagenet_jax(data_dir): def setup_imagenet_pytorch(data_dir): train_tar_file_path = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) + test_dir_path = os.path.join(data_dir, 'imagenet_v2') # Setup jax dataset dir imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch') @@ -499,13 +512,18 @@ def setup_imagenet_pytorch(data_dir): os.makedirs(os.path.join(imagenet_pytorch_data_dir, 'train')) os.makedirs(os.path.join(imagenet_pytorch_data_dir, 'val')) - # Copy tar file into pytorch directory - logging.info('Copying {} to {}'.format(train_tar_file_path, - imagenet_pytorch_data_dir)) + # Move tar files and imagenet_v2 into pytorch directory + logging.info('Moving {} to {}'.format(train_tar_file_path, + imagenet_pytorch_data_dir)) shutil.move(train_tar_file_path, imagenet_pytorch_data_dir) - logging.info('Copying {} to {}'.format(val_tar_file_path, - imagenet_pytorch_data_dir)) + logging.info('Moving {} to {}'.format(val_tar_file_path, + imagenet_pytorch_data_dir)) shutil.move(val_tar_file_path, imagenet_pytorch_data_dir) + if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): + logging.info('Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) + shutil.move(test_dir_path, + os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')) # Extract train data\ logging.info('Extracting imagenet train data') @@ -549,11 +567,12 @@ def download_librispeech(dataset_dir, tmp_dir): # After extraction the result is a folder named Librispeech containing audio # files in .flac format along with transcripts containing name of audio file # and corresponding transcription. 
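# With this change the raw archives are downloaded and extracted under
# tmp_dir/librispeech, while the preprocessed output is written to
# dataset_dir/librispeech (previously 'librispeech_processed').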
- tmp_librispeech_dir = os.path.join(dataset_dir, 'librispeech') + tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech') extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech') - final_data_dir = os.path.join(dataset_dir, 'librispeech_processed') + final_data_dir = os.path.join(dataset_dir, 'librispeech') _maybe_mkdir(tmp_librispeech_dir) + _maybe_mkdir(final_data_dir) for split in ['dev', 'test']: for version in ['clean', 'other']: @@ -597,11 +616,13 @@ def download_mnist(data_dir): def download_ogbg(data_dir): + data_dir = os.path.join(data_dir, 'ogbg') tfds.builder('ogbg_molpcba:0.1.3', data_dir=data_dir).download_and_prepare() def download_wmt(data_dir): """WMT14 and WMT17 de-en.""" + data_dir = os.path.join(data_dir, 'wmt') for ds_name in ['wmt14_translate/de-en:1.0.0', 'wmt17_translate/de-en:1.0.0']: dataset_builder = tfds.builder(ds_name, data_dir=data_dir) dataset_builder.download_and_prepare() diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index 0968f2a00..acdaa8e98 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -32,6 +32,7 @@ 'train-clean-360': 104014, 'train-other-500': 148688, 'test-clean': 2620, + 'test-other': 2939, 'dev-clean': 2703, 'dev-other': 2864, } diff --git a/setup.cfg b/setup.cfg index 6f53cd51b..a7ce5ebb2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -115,6 +115,7 @@ jax_core_deps = # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. chex==0.1.7 + ml_dtypes==0.2.0 # JAX CPU jax_cpu = diff --git a/submission_runner.py b/submission_runner.py index f4ee32ede..2289d39d3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -133,6 +133,11 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') +flags.DEFINE_integer( + 'rng_seed', + None, + 'Value of rng seed. If None, a random seed will' + 'be generated from hardware.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -173,6 +178,7 @@ def train_once( update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, hyperparameters: Optional[spec.Hyperparameters], + rng_seed: int, rng: spec.RandomState, profiler: Profiler, max_global_steps: int = None, @@ -267,10 +273,9 @@ def train_once( global_step, preemption_count, checkpoint_dir=log_dir) - meta_data = logger_utils.get_meta_data(workload) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.write_json(meta_file_name, meta_data) + logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -449,7 +454,8 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space: Optional[str] = None, num_tuning_trials: Optional[int] = None, log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True): + save_checkpoints: Optional[bool] = True, + rng_seed: Optional[int] = None): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -496,7 +502,8 @@ def score_submission_on_workload(workload: spec.Workload, all_metrics = [] for hi, hyperparameters in enumerate(tuning_search_space): # Generate a new seed from hardware sources of randomness for each trial. 
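# Illustrative sketch, not part of the patch above: with the new --rng_seed
# flag a run can be reproduced exactly, while leaving the flag unset keeps the
# old behaviour of drawing a fresh seed from hardware entropy.
# `derive_rng_seed` is a hypothetical helper name used only for illustration.
import os
import struct

import jax


def derive_rng_seed(flag_value=None):
  if flag_value is not None:
    return flag_value
  # os.urandom(4) returns 4 random bytes; 'I' unpacks them as a uint32.
  return struct.unpack('I', os.urandom(4))[0]


seed = derive_rng_seed()          # e.g. 2716052439, different on every run
rng = jax.random.PRNGKey(seed)    # log `seed` to make the run reproducible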
- rng_seed = struct.unpack('I', os.urandom(4))[0] + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) rng = prng.PRNGKey(rng_seed) # Because we initialize the PRNGKey with only a single 32 bit int, in the @@ -528,7 +535,9 @@ def score_submission_on_workload(workload: spec.Workload, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - hyperparameters, rng, + hyperparameters, + rng_seed, + rng, profiler, max_global_steps, tuning_dir_name, @@ -545,7 +554,8 @@ def score_submission_on_workload(workload: spec.Workload, logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) else: - rng_seed = struct.unpack('q', os.urandom(8))[0] + if not rng_seed: + rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) # If the submission is responsible for tuning itself, we only need to run it # once and return the total time. @@ -554,7 +564,7 @@ def score_submission_on_workload(workload: spec.Workload, workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - None, rng, profiler, max_global_steps, log_dir, + None, rng_seed, rng, profiler, max_global_steps, log_dir, save_checkpoints=save_checkpoints) return score @@ -610,7 +620,8 @@ def main(_): tuning_search_space=FLAGS.tuning_search_space, num_tuning_trials=FLAGS.num_tuning_trials, log_dir=logging_dir_path, - save_checkpoints=FLAGS.save_checkpoints) + save_checkpoints=FLAGS.save_checkpoints, + rng_seed=FLAGS.rng_seed) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 761da427b..9a95f3656 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -35,7 +35,17 @@ def key_transform(k): return tuple(new_key) -sd_transform = None +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k].cpu()) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks, dim=0) + return out + if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 52c96481c..806022687 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -47,20 +47,38 @@ def sd_transform(sd): out = {} for k in sd: k_str = ''.join(k) - if 'Dense' in k_str: - new_key = (*k[:2], 'MlpBlock_0', *k[2:]) - out[new_key] = sd[k] - elif 'SelfAttention' in k_str: + if 'SelfAttention' in k_str: new_key = list(k) - if '_' in new_key[-1]: - qkv = {'q': 'query', 'k': 'key', 'v': 'value'}[new_key[-1][0]] - new_key[-1] = qkv - new_key.append('kernel') new_key = [ i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' for i in new_key ] - new_key = tuple(new_key) + if 'SelfAttention_0' in k_str: + if new_key[-2] == 'Dense_0': + # qkv + for name, value in zip(('query', 'key', 'value'), sd[k].chunk(3)): + out[(*new_key[:-2], name, new_key[-1])] = value + pass + elif new_key[-2] == 'Dense_1': + # out + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] + pass + else: + if new_key[-2] == 'Dense_0': + #q + out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] + pass + elif new_key[-2] == 'Dense_1': + # kv + for name, value in zip(('key', 'value'), sd[k].chunk(2)): + out[(*new_key[:-2], name, new_key[-1])] = value + pass + elif new_key[-2] == 
'Dense_2': + # out + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] + pass + elif 'Dense' in k_str: + new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] elif 'LayerNorm' in k_str: new_key = list(k) diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index 5b33d8b62..b67625213 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -55,9 +55,15 @@ def test_param_shapes(workload): jax_workload.param_shapes.unfreeze()) pytorch_param_shapes = jax.tree_util.tree_leaves( pytorch_workload.param_shapes) - if workload == 'criteo1tb': - # The PyTorch implementation divides the embedding matrix - # into 3 chunks. + if workload == 'wmt': + # The PyTorch transformer for WMT is implemented with fused linear layers + # for the projection of QKV inside of the MultiheadAttention module. + # Two weight matrices for each of the two self-attention layers less and one + # less for the encoder-decoder attention layer -> 5 weight matrices less. + # We have 6 encoder/decoder layers, hence 30 weight matrices less in total. + assert len(jax_param_shapes) == len(pytorch_param_shapes) + 30 + elif workload == 'criteo1tb': + # The PyTorch implementation divides the embedding matrix into 3 chunks. assert len(jax_param_shapes) == len(pytorch_param_shapes) - 3 else: assert len(jax_param_shapes) == len(pytorch_param_shapes) diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 45e855759..7cf8f63c3 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -71,6 +71,12 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0), } + num_kv = { + 'jax': + jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), + 'pytorch': + pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), + } num_q = { 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0), @@ -96,11 +102,13 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0), } qkv_match = num_qkv['jax'] == num_qkv['pytorch'] + kv_match = num_kv['jax'] == num_kv['pytorch'] q_match = num_q['jax'] == num_q['pytorch'] k_match = num_k['jax'] == num_k['pytorch'] v_match = num_v['jax'] == num_v['pytorch'] bias_match = num_bias['jax'] == num_bias['pytorch'] - qkv_match = qkv_match and q_match and k_match and v_match and bias_match + qkv_match = ( + qkv_match and kv_match and q_match and k_match and v_match and bias_match) # We subtract 2 * num_qkv from the number of biases because there are 2 # missing for each of q, k, v. @@ -112,7 +120,12 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv['pytorch'] and (num_qkv['pytorch'] != 0 and (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch'])) - qkv_match = qkv_match or jax_qkv_match or pytorch_qkv_match + pytorch_kv_match = ( + num_q['jax'] == num_k['jax'] == num_v['jax'] == + num_qkv['pytorch'] + num_kv['pytorch'] and + num_q['pytorch'] == num_kv['pytorch']) + qkv_match = ( + qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match) return qkv_match @@ -149,6 +162,7 @@ def test_param_types(workload_name): # Check if total number of each type match. 
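# Illustrative sketch, not part of the patch above, of the fused-projection
# bookkeeping behind the WMT compare transform and the `+ 30` assertion in
# test_param_shapes: PyTorch fuses q/k/v into a single in_proj matrix for
# self-attention and k/v into a single kv_proj matrix for encoder-decoder
# attention, so chunk(3) / chunk(2) recovers the separate per-projection
# kernels. Shapes below are toy values.
import torch

d_model = 8
in_proj = torch.randn(3 * d_model, d_model)   # fused self-attention qkv
query_w, key_w, value_w = in_proj.chunk(3)    # 1 fused matrix -> 3 kernels
kv_proj = torch.randn(2 * d_model, d_model)   # fused cross-attention kv
key_w2, value_w2 = kv_proj.chunk(2)           # 1 fused matrix -> 2 kernels

# Per encoder/decoder layer pair the two self-attention blocks save two
# matrices each and the cross-attention block saves one, i.e. five fewer
# matrices; with six layers that is the 30-matrix gap asserted above.
num_layers = 6
assert num_layers * (2 * (3 - 1) + (3 - 2)) == 30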
attention_keys = { spec.ParameterType.ATTENTION_QKV, + spec.ParameterType.ATTENTION_KV, spec.ParameterType.ATTENTION_Q, spec.ParameterType.ATTENTION_K, spec.ParameterType.ATTENTION_V, diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index fec1f9085..a1b64a573 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -42,14 +42,14 @@ def test_workload(self): jax_logs = '/tmp/jax_log.pkl' pyt_logs = '/tmp/pyt_log.pkl' run( - f'python3 tests/reference_algorithm_tests.py --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}' + f'python3 -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}' f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', shell=True, stdout=DEVNULL, stderr=STDOUT, check=True) run( - f'torchrun --standalone --nnodes 1 --nproc_per_node 8 tests/reference_algorithm_tests.py --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pyt_logs}' + f'torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pyt_logs}' f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', shell=True, stdout=DEVNULL,