
Fix Jitter Noise Passing to Experts in Switch Transformers #33969 #35847

Open

sambhavnoobcoder wants to merge 9 commits into main

Conversation

sambhavnoobcoder

Issue Description

This pull request addresses a bug in the Switch Transformers implementation where jitter noise, intended to be applied only for routing decisions, was also unintentionally applied to the expert inputs.

Fixes: #33969

Problem Statement

In Switch Transformers, a small amount of multiplicative jitter noise is added to the inputs at routing time to encourage route diversity during training. However, these jittered inputs were also passed along to the experts, which contradicts the original paper’s design and led to unexpected discrepancies in the outputs.

Root Cause

The code used the same hidden-states tensor for both routing and expert processing. When jitter noise was enabled in training mode, the noise was applied to this tensor in place, so the experts received the noisy inputs as well.
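
For illustration, the problematic pattern looks roughly like the following. This is a minimal sketch assuming a multiplicative jitter as described in the Switch Transformers paper; the function and argument names are placeholders, not the actual library code:

```python
import torch

def compute_router_probabilities_buggy(hidden_states, classifier, jitter_noise, training):
    # BUG (sketch): the in-place multiplication mutates the caller's tensor,
    # so the experts later see the jittered values instead of the originals.
    if training and jitter_noise > 0:
        hidden_states *= torch.empty_like(hidden_states).uniform_(
            1.0 - jitter_noise, 1.0 + jitter_noise
        )
    router_logits = classifier(hidden_states)
    return torch.softmax(router_logits, dim=-1), router_logits
```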

Implementation

  1. We now clone the original hidden states before applying jitter noise.
  2. A separate copy is used exclusively for computing router logits and probabilities.
  3. The unchanged hidden states are then fed to the experts to maintain the original semantics (see the sketch after this list).
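
A minimal sketch of the corrected flow, under the same assumptions as the snippet above (placeholder names, not the exact method in the library):

```python
import torch

def compute_router_probabilities_fixed(hidden_states, classifier, jitter_noise, training):
    # Clone first so the jitter never touches the tensor the experts will see.
    router_inputs = hidden_states.clone()
    if training and jitter_noise > 0:
        router_inputs *= torch.empty_like(router_inputs).uniform_(
            1.0 - jitter_noise, 1.0 + jitter_noise
        )
    # Router logits/probabilities come from the jittered copy only;
    # `hidden_states` itself is left untouched for the expert forward pass.
    router_logits = classifier(router_inputs)
    return torch.softmax(router_logits, dim=-1), router_logits
```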


Test Cases

  1. test_router_training_mode
    • Objective: Ensures that jitter noise is only applied during training.
    • Checks that outputs differ between consecutive runs (due to noise) but original inputs remain unchanged.

  2. test_router_jitter_noise_separation
    • Objective: Verifies that jitter noise affects only the router’s internal computations and not the expert inputs.
    • Confirms the logits differ when jitter is applied, while the main input stays the same.

  3. test_expert_inputs_consistency
    • Objective: Asserts that all expert inputs remain consistent, even when jitter is applied during training.
    • Uses a forward hook on the first expert to capture its inputs across multiple runs and compares them (see the sketch after this list).
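
One way such a hook-based check can be realized is sketched below. The layer’s call signature and the way the first expert is reached are assumptions for illustration, not the actual test code:

```python
import torch

def check_expert_receives_clean_inputs(layer, first_expert, hidden_states, runs=3):
    # Forward pre-hook records the tensor the expert actually receives.
    captured = []
    handle = first_expert.register_forward_pre_hook(
        lambda module, inputs: captured.append(inputs[0].detach().clone())
    )
    layer.train()  # jitter noise is only active in training mode
    for _ in range(runs):
        layer(hidden_states)  # assumed call signature of the sparse MLP layer
    handle.remove()

    # Every token the expert received must be an unmodified row of the
    # original input: a jittered row would not match any original row exactly.
    original_rows = hidden_states.reshape(-1, hidden_states.shape[-1])
    for cap in captured:
        for row in cap.reshape(-1, cap.shape[-1]):
            assert any(torch.allclose(row, orig) for orig in original_rows)
```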

With these changes and the added tests, the Switch Transformers implementation adheres to the original design while preserving backward compatibility and correctness.

cc @ArthurZucker
