
Jitter Noise added to input being passed to experts in Switch Transformers #33969

Open
karan-uppal3 opened this issue Oct 5, 2024 · 11 comments · May be fixed by #35847
Labels: bug, Core: Modeling (Internals of the library; Models)


@karan-uppal3 (Contributor)

System Info

  • transformers version: 4.44.2
  • Platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
import torch.nn as nn
from transformers import (
    SwitchTransformersConfig,
    SwitchTransformersTop1Router,
)
from transformers.models.switch_transformers.modeling_switch_transformers import SwitchTransformersDenseActDense


class MySwitchTransformersSparseMLP(nn.Module):
    r"""
    Implementation of the Switch Transformers Sparse MLP module.
    """

    def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
        super().__init__()
        # Step 1: Get the correct router according to its class
        self.router = SwitchTransformersTop1Router(config)

        # Step 2: Get the experts
        self.experts = nn.ModuleDict()
        for idx in range(config.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

        1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
        and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
        hidden states: they are broadcast to the hidden state values (they can be interpreted as a scaling factor).

        2- Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign
        each expert its corresponding hidden states.

        """

        prev_save = hidden_states.clone()

        # Step 1: Get the router_mask from the router as well as the probabilities
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = torch.argmax(router_mask, dim=-1)

        print(torch.allclose(prev_save, hidden_states))
        print(torch.mean(prev_save - hidden_states))

        # The router might not always map every token to an expert, which means that some hidden states
        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

        next_states = hidden_states.clone()

        router_mask = router_mask.bool()
        batch_size, seq_len, num_experts = router_mask.shape
        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
            0
        ].tolist()  # length: number of "activated" expert / value: index
        for idx in idx_mask:
            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
                hidden_states[router_mask[:, :, idx]]
            )

        hidden_states = router_probs * next_states
        return hidden_states, (router_logits, expert_index)

config = SwitchTransformersConfig()
model = MySwitchTransformersSparseMLP(config)

model.train()
in_data = torch.ones(1, 1, 768)
out = model(in_data)

The output is

False
tensor(-0.0001)

which ideally should give True and the mean difference should be zero.

This is because in SwitchTransformersTop1Router the hidden_states are multiplied in place by jitter noise, so the noise persists when they are subsequently passed to the experts.

if self.training and self.jitter_noise > 0:
    # Multiply the token inputs by the uniform distribution - adding some noise
    hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
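
For illustration only, here is a minimal standalone sketch (not library code) of why this matters: `*=` mutates the tensor in place, so the caller's hidden_states still carry the noise after the router returns.

import torch

torch.manual_seed(0)
hidden_states = torch.ones(1, 1, 4)
caller_view = hidden_states  # the sparse MLP layer still holds this exact tensor object

jitter_noise = 0.01
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)

print(hidden_states is caller_view)                      # True: same underlying storage
print(torch.allclose(caller_view, torch.ones(1, 1, 4)))  # False: the caller's tensor was modified too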

Expected behavior

Ideally, no jitter noise should be present in the input passed to the experts, so the script above would print True and a mean difference of 0.
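
A minimal sketch of that expected separation (the helper below is hypothetical and not part of the transformers API; it only mirrors the idea of keeping a jittered copy for routing while the original tokens go to the experts):

import torch

def jittered_routing_copy(hidden_states: torch.Tensor, jitter_noise: float, training: bool) -> torch.Tensor:
    """Hypothetical helper: return a jittered copy for the router, leaving the input untouched."""
    if training and jitter_noise > 0:
        # Out-of-place multiplication: the caller's hidden_states keep their original values,
        # while the router can compute its logits on the noisy copy.
        return hidden_states * torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)
    return hidden_states

tokens = torch.ones(1, 1, 768)
gate_inputs = jittered_routing_copy(tokens, jitter_noise=0.01, training=True)
print(torch.allclose(tokens, torch.ones(1, 1, 768)))  # True: the expert inputs stay clean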

@ArthurZucker (Collaborator)

Hey! Where is the ideal scenario coming from? 🤗 We tried to follow the original implementation on this!

@karan-uppal3 (Contributor, Author) commented Oct 7, 2024

Hello @ArthurZucker! According to the pseudo code given in Figure 15 and Figure 16 of the original paper, the input to the experts doesn't contain the additional jitter noise.

@LysandreJik LysandreJik added the Core: Modeling Internals of the library; Models. label Oct 11, 2024
@ArthurZucker (Collaborator)

Did you have a look at the actual code provided by the authors? 🤗

@karan-uppal3 (Contributor, Author) commented Oct 25, 2024

Upon inspecting their original implementation, specifically here

It seems that gate_inputs is created as a copy of inputs and is only used for the gating function. On the other hand, inputs is what is passed to the experts.

@ArthurZucker (Collaborator)

Do we agree that right after this:

  gate_inputs = mtf.to_float(inputs)

  # Input perturbations
  if policy == "input_dropout":
    gate_inputs = mtf.dropout(
        gate_inputs,
        is_training=train,
        keep_prob=1.0 - hparams.moe_switch_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_switch_jitter)

we have jitter applied and gate_inputs has the jitter noise added, no?

@karan-uppal3 (Contributor, Author)

Only gate_inputs has jitter noise added. However, inputs, which is what is passed to the experts, does not have jitter noise.
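
To make the distinction concrete, here is a rough PyTorch analogue of the excerpt above (illustrative only; the clone stands in for mtf.to_float producing a separate gating tensor):

import torch

inputs = torch.ones(2, 4, 8)  # the tokens that are later dispatched to the experts
gate_inputs = inputs.clone()  # a separate tensor used only for the gating function

moe_switch_jitter = 0.01
# Counterpart of mtf.layers.multiplicative_jitter(gate_inputs, hparams.moe_switch_jitter)
gate_inputs = gate_inputs * torch.empty_like(gate_inputs).uniform_(
    1.0 - moe_switch_jitter, 1.0 + moe_switch_jitter
)

print(torch.allclose(inputs, torch.ones(2, 4, 8)))  # True: the expert inputs carry no jitter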

github-actions bot commented Dec 1, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@karan-uppal3 (Contributor, Author)

Awaiting response


@github-actions github-actions bot closed this as completed Jan 3, 2025
@ArthurZucker ArthurZucker reopened this Jan 21, 2025
@ArthurZucker (Collaborator)

Sorry for being late. Okay, do you want to open a PR for a fix?

@sambhavnoobcoder

@ArthurZucker I saw this was open and wrote a small fix. I hope you can review it and merge it if the issue is resolved; I'll address any comments you have on it as well. Also, @karan-uppal3, it would be great if you could verify whether this meets the requirements. I'll make any changes accordingly.
