Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BN Fixes #783

Merged
merged 3 commits into from
Oct 29, 2024
Merged

BN Fixes #783

merged 3 commits into from
Oct 29, 2024

Conversation

adefazio
Copy link
Contributor

@adefazio adefazio commented Sep 5, 2024

There are some subtle issues with how BatchNorm is handled in the PyTorch version of the code. Currently, workload.model_fn has an update_batch_norm parameter, which in theory should allow the submission to control whether the batch-norm statistics are updated during a forward pass. The issues are the following:

  • The update_batch_norm_fn function stores the old momentum parameter for each batchnorm layer in a momentum_backup variable, so it can be restored later, before zeroing the parameter. However, if it is called with update_batch_norm=False twice in a row, it overwrites the momentum_backup with 0 on the second call, so momentum then remains zero for the remainder of training.
  • In PyTorch's bultin BatchNorm, 0 indicates that the momentum buffer shouldn't be updated. This is the opposite of how EMA momentum is usually done (i.e. in Adam), where 1 would indicate that it shouldn't be updated, and 0 means it's set to the latest value at every step. The custom BatchNorm modules used in the two librispeech workloads follows this second, more standard convention instead. However, the update_batch_norm_fn sets the momentum to zero for all three layer types, resulting in incorrect behavior for the librispeech workloads.
  • The update_batch_norm_fn sets the BN layers to eval mode. This doesn't make sense as it prevents the use-case where you use batch-computed statistics (train mode) without also updating the running statistics. The BN layers can bet set to eval mode separately by passing in ForwardPassMode.EVAL to the forward pass, so removing this .eval() call doesn't prevent the submission from using eval mode during a forward pass.

This PR changes switch the custom BN code to follow the BN convention so that momentum=0 doesn't update the running buffers. It also fixes the issues in the update_batch_norm_fn function mentioned above.

priyakasimbeg and others added 3 commits July 30, 2024 07:29

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
Dev -> Main

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
Dev -> main
@adefazio adefazio requested a review from a team as a code owner September 5, 2024 15:39
Copy link

github-actions bot commented Sep 5, 2024

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@priyakasimbeg
Copy link
Contributor

priyakasimbeg commented Sep 5, 2024

Thanks for spotting all these issues. I agree, we should incorporate these fixes.
I spotted one more subtle issue in our JAX code similar to 3 here.

  def __call__(self,
               x: spec.Tensor,
               update_batch_norm: bool = True) -> spec.Tensor:
    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
    norm = functools.partial(
        nn.BatchNorm,
        use_running_average=not update_batch_norm,
        momentum=0.9,
        epsilon=1e-5,
        dtype=self.dtype)

This prevents the use case where you (don't) want to use the running average in train mode and (don't) want update the batch_norm statistics. Maybe we need an extra arg in the call functions to distinguish between train and eval mode (or just whether or not to use_running_average) instead of inferring from the update_batch_norm arg?

@adefazio
Copy link
Contributor Author

adefazio commented Sep 6, 2024

recheck

@priyakasimbeg priyakasimbeg changed the base branch from main to dev September 12, 2024 19:20
Copy link
Contributor

@priyakasimbeg priyakasimbeg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. JAX changes are in follow up PR.
Nit: Can you run `yapf -i -r -vv -p on the files that are failing the Linting and yapf tests'?

@priyakasimbeg priyakasimbeg mentioned this pull request Oct 17, 2024
@priyakasimbeg priyakasimbeg merged commit b24812f into mlcommons:dev Oct 29, 2024
30 of 36 checks passed
@github-actions github-actions bot locked and limited conversation to collaborators Oct 29, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants