Fix Jitter Noise Passing to Experts in Switch Transformers #33969 #35847
Issue Description
This pull request fixes a bug in the Switch Transformers implementation where jitter noise, intended to be applied only to the routing decision, was also unintentionally applied to the expert inputs.
Fixes: #33969
Problem Statement
In Switch Transformers, a small amount of jitter noise is added to the inputs at routing time to ensure route diversity. However, these jittered inputs were incorrectly passed along to the experts, which contradicts the original paper’s design and led to unexpected discrepancies in outputs.
Root Cause
The code used the same hidden states tensor for both routing and expert processing. When jitter noise was enabled in training mode, the router modified this tensor in place, so the experts received the jittered inputs.
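For illustration, the problematic pattern looked roughly like the following. This is a simplified sketch, not the actual source: the class name, shapes, and the `0.1` default are placeholders, but the key detail is the in-place `*=` on the shared tensor.

```python
import torch
import torch.nn as nn


class Top1RouterBuggy(nn.Module):
    """Simplified router illustrating the bug: jitter is applied in place."""

    def __init__(self, hidden_size: int, num_experts: int, jitter_noise: float = 0.1):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_experts, bias=False)
        self.jitter_noise = jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.training and self.jitter_noise > 0:
            # In-place multiplication mutates the caller's tensor, so the experts
            # downstream end up seeing the jittered hidden states as well.
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        return self.classifier(hidden_states)  # router logits
```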
Implementation
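At a high level, the fix keeps the jittered tensor local to the routing computation, so the experts always receive the original hidden states. Below is a minimal sketch of the idea; the class and variable names are illustrative and this is not the exact diff from the PR.

```python
import torch
import torch.nn as nn


class Top1RouterFixed(nn.Module):
    """Router sketch: jitter is applied out of place, only for the routing logits."""

    def __init__(self, hidden_size: int, num_experts: int, jitter_noise: float = 0.1):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_experts, bias=False)
        self.jitter_noise = jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.training and self.jitter_noise > 0:
            # Out-of-place multiplication: a new jittered tensor is created for
            # routing only, while the caller's hidden_states stay untouched.
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        return self.classifier(hidden_states)  # router logits


class SparseMLPSketch(nn.Module):
    """Toy sparse MLP: experts always see the original, un-jittered hidden states."""

    def __init__(self, hidden_size: int = 8, num_experts: int = 4):
        super().__init__()
        self.router = Top1RouterFixed(hidden_size, num_experts)
        self.experts = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.router(hidden_states)   # jitter stays inside the router
        expert_index = router_logits.argmax(dim=-1)  # top-1 routing decision
        output = torch.zeros_like(hidden_states)
        for idx, expert in enumerate(self.experts):
            mask = expert_index == idx
            if mask.any():
                # The original inputs, not the jittered copy, go to each expert.
                output[mask] = expert(hidden_states[mask])
        return output
```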
Screenshot
Test Cases
test_router_training_mode
• Objective: Ensures that jitter noise is only applied during training.
• Checks that outputs differ between consecutive runs (due to noise) but original inputs remain unchanged.
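A rough sketch of this check, reusing the toy `Top1RouterFixed` from the sketch above (not the actual test code from this PR):

```python
import torch


def test_router_training_mode():
    torch.manual_seed(0)
    router = Top1RouterFixed(hidden_size=8, num_experts=4, jitter_noise=0.1)
    hidden_states = torch.randn(2, 5, 8)
    original = hidden_states.clone()

    router.train()
    logits_a = router(hidden_states)
    logits_b = router(hidden_states)

    # Noise is resampled each call, so training-mode logits differ across runs...
    assert not torch.allclose(logits_a, logits_b)
    # ...but the caller's hidden states must not be mutated.
    assert torch.equal(hidden_states, original)

    router.eval()
    # Without jitter, eval-mode logits are deterministic.
    assert torch.equal(router(hidden_states), router(hidden_states))
```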
test_router_jitter_noise_separation
• Objective: Verifies that jitter noise affects only the router’s internal computations and not the expert inputs.
• Confirms the logits differ when jitter is applied, while the main input stays the same.
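A sketch of this check under the same toy router (illustrative only):

```python
import torch


def test_router_jitter_noise_separation():
    torch.manual_seed(0)
    router = Top1RouterFixed(hidden_size=8, num_experts=4, jitter_noise=0.1)
    router.train()
    hidden_states = torch.randn(2, 5, 8)
    original = hidden_states.clone()

    noisy_logits = router(hidden_states)

    router.jitter_noise = 0.0
    clean_logits = router(hidden_states)

    # Jitter changes the routing logits...
    assert not torch.allclose(noisy_logits, clean_logits)
    # ...but never the tensor that the experts will consume.
    assert torch.equal(hidden_states, original)
```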
test_expert_inputs_consistency
• Objective: Asserts that all expert inputs remain consistent, even when jitter is applied during training.
• Uses a forward hook on the first expert to capture its inputs across multiple runs and compares them.
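A simplified sketch of the same idea, reusing `SparseMLPSketch` from above. Because the routing itself can change as noise is resampled, this sketch checks that every token the first expert receives is an exact, un-jittered row of the original input, rather than comparing runs directly; the real test may be structured differently.

```python
import torch


def test_expert_inputs_consistency():
    torch.manual_seed(0)
    mlp = SparseMLPSketch(hidden_size=8, num_experts=4)
    mlp.train()
    hidden_states = torch.randn(2, 5, 8)
    flat_original = hidden_states.reshape(-1, 8)

    captured = []
    # Forward hook on the first expert records the tensor it actually receives.
    handle = mlp.experts[0].register_forward_hook(
        lambda module, inputs, output: captured.append(inputs[0].detach().clone())
    )
    try:
        for _ in range(3):
            mlp(hidden_states)
    finally:
        handle.remove()

    # Every token seen by the expert must be an exact, un-jittered input row.
    # (If routing happens to send no tokens to expert 0, the check is vacuous.)
    for batch in captured:
        for token in batch:
            assert any(torch.equal(token, row) for row in flat_original)
```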
With these changes and the added tests, the Switch Transformers implementation adheres to the original paper's design while preserving backward compatibility and correctness.
cc @ArthurZucker