Jitter Noise added to input being passed to experts in Switch Transformers #33969
Comments
Hey! Where is the
Hello @ArthurZucker! According to the pseudocode given in Figure 15 and Figure 16 of the original paper, the input to the experts doesn't contain the additional jitter noise.
Did you have a look at the actual code provided by the authors? 🤗
Upon inspecting their original implementation, specifically here, it seems that
Do we agree that right after this:

```python
gate_inputs = mtf.to_float(inputs)
# Input perturbations
if policy == "input_dropout":
  gate_inputs = mtf.dropout(
      gate_inputs,
      is_training=train,
      keep_prob=1.0 - hparams.moe_switch_dropout)
elif train and policy == "input_jitter":
  gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                  hparams.moe_switch_jitter)
```

jitter has been applied, and `gate_inputs` contains the added jitter noise, no?
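The routing flow in the Mesh TensorFlow excerpt above can be sketched in NumPy to show where the noise ends up. This is a minimal sketch, not the authors' code: `multiplicative_jitter` here is a hypothetical stand-in for `mtf.layers.multiplicative_jitter`, `switch_route` and `w_gate` are illustrative names, and it assumes (as the excerpt suggests) that the perturbed `gate_inputs` feed only the router logits while the experts receive the original `inputs`.

```python
import numpy as np


def multiplicative_jitter(x, epsilon, rng):
    # Hypothetical stand-in for mtf.layers.multiplicative_jitter: scale each
    # element by a factor drawn from U(1 - epsilon, 1 + epsilon).
    noise = rng.uniform(1.0 - epsilon, 1.0 + epsilon, size=x.shape)
    return x * noise  # out of place: x itself is not modified


def switch_route(inputs, w_gate, epsilon, rng, train=True):
    # Perturb a separate gate_inputs array that is used only for routing.
    gate_inputs = inputs.astype(np.float32)  # astype copies by default
    if train:
        gate_inputs = multiplicative_jitter(gate_inputs, epsilon, rng)
    logits = gate_inputs @ w_gate  # shape (tokens, num_experts)
    expert_index = np.argmax(logits, axis=-1)  # top-1 expert per token
    # Assumption: the experts receive the original, unperturbed inputs.
    return expert_index, inputs


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8)).astype(np.float32)
w = rng.normal(size=(8, 3)).astype(np.float32)
idx, expert_inputs = switch_route(x, w, epsilon=0.01, rng=rng)
print(idx.shape, expert_inputs is x)  # (4,) True
```

Under this reading, the jitter only influences which expert each token is routed to; the values handed to that expert are unchanged.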
Only
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Awaiting response
Sorry for being late. Okay, do you want to open a PR with a fix?
@ArthurZucker I saw this was still open and wrote a small fix. I hope you can review it and merge it if it resolves the issue. I'll address any comments you have on it if required. Also, @karan-uppal3, it would be great if you could verify whether this meets the requirements. I'll make any changes accordingly.
System Info
Who can help?
@ArthurZucker @younesbelkada
Information
Tasks
`examples` folder (such as GLUE/SQuAD, ...)
Reproduction
The output is
which should ideally be True, with a mean difference of zero.
This is because, in `SwitchTransformersTop1Router`, the `hidden_states` are multiplied by jitter noise, and that noise persists when they are passed to the experts.
transformers/src/transformers/models/switch_transformers/modeling_switch_transformers.py
Lines 159 to 161 in e71a01a
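One plausible way the noise can leak to the experts is aliasing: assuming the router first casts `hidden_states` with `.to(...)` (which in PyTorch returns the same tensor when no conversion is needed) and then scales it with an in-place `*=`, the caller's tensor is modified too. The NumPy sketch below mimics that pattern with `astype(..., copy=False)`; `router_with_inplace_jitter` is a hypothetical name, not the Transformers function.

```python
import numpy as np


def router_with_inplace_jitter(hidden_states, jitter_noise, rng):
    # Hypothetical router that casts, then scales in place.
    # astype(..., copy=False) returns the same array when the dtype already
    # matches, similar to torch.Tensor.to() returning self in that case.
    h = hidden_states.astype(np.float32, copy=False)
    h *= rng.uniform(1.0 - jitter_noise, 1.0 + jitter_noise, size=h.shape)
    return h  # would be used to compute the router logits


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 4)).astype(np.float32)
before = hidden.copy()
router_with_inplace_jitter(hidden, jitter_noise=0.1, rng=rng)
# The caller's array was mutated, so the experts would see jittered inputs.
print(np.allclose(before, hidden), float(np.abs(before - hidden).mean()))
```

Because `h` and `hidden` are the same object here, the in-place multiply changes the array that is later handed to the experts.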
Expected behavior
Ideally, no jitter noise should be present in the input passed to the experts, so the check should return True with a mean difference of 0.
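The expected behavior can be illustrated with a sketch of a router that applies the noise out of place, plus a check mirroring the True / mean-difference comparison described in this report. Both `router_with_jitter` and `expert_input_is_clean` are hypothetical names, and NumPy stands in for PyTorch; this is not the fix that was submitted.

```python
import numpy as np


def router_with_jitter(hidden_states, jitter_noise, rng):
    # Hypothetical fixed router: the noise is applied out of place, so a new
    # array is created and the caller's hidden_states are left untouched.
    h = hidden_states.astype(np.float32, copy=False)
    noise = rng.uniform(1.0 - jitter_noise, 1.0 + jitter_noise, size=h.shape)
    h = h * noise
    return h  # would be used to compute the router logits


def expert_input_is_clean(hidden_states, jitter_noise, rng):
    # Compare the input before and after routing.
    before = hidden_states.copy()
    router_with_jitter(hidden_states, jitter_noise, rng)
    same = bool(np.allclose(before, hidden_states))
    mean_diff = float(np.abs(before - hidden_states).mean())
    return same, mean_diff


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 4)).astype(np.float32)
print(expert_input_is_clean(hidden, jitter_noise=0.1, rng=rng))  # (True, 0.0)
```

In PyTorch terms, this corresponds to replacing an in-place `*=` with an out-of-place `hidden_states = hidden_states * noise` (or cloning first), so the router still sees jittered values while the experts do not.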