Config fix #39

Open · wants to merge 4 commits into main
2 changes: 1 addition & 1 deletion examples/fairseq/tasks/data/utils.py
@@ -62,7 +62,7 @@ def close(self):
pass


class WeightIterator(object):
class WeightIterator:
def __init__(self, weights, seed):
self.weights = weights
self.seed = seed
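Note (not part of the diff): dropping the explicit `object` base is purely cosmetic in Python 3, where every class is new-style. A minimal sketch of the equivalence; the same simplification is applied to `set_torch_seed` in `feedforward_network.py` below.

```python
# Illustrative only: in Python 3 the explicit (object) base adds nothing.
class WithExplicitBase(object):
    pass

class WithoutExplicitBase:
    pass

# Both classes inherit directly from object, so their MROs match after the class itself.
assert WithExplicitBase.__mro__[1:] == WithoutExplicitBase.__mro__[1:] == (object,)
```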
158 changes: 28 additions & 130 deletions torchscale/architecture/config.py
@@ -2,14 +2,8 @@
# Licensed under The MIT License [see LICENSE for details]


class EncoderConfig(object):
class Config:
def __init__(self, **kwargs):
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
self.encoder_layers = kwargs.pop("encoder_layers", 12)
self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
self.normalize_output = kwargs.pop("normalize_output", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
@@ -35,31 +29,15 @@ def __init__(self, **kwargs):
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_encoder_input_output_embed = kwargs.pop(
"share_encoder_input_output_embed", False
)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Vision
self.img_size = kwargs.pop("img_size", 224)
self.patch_size = kwargs.pop("patch_size", 16)
self.in_chans = kwargs.pop("in_chans", 3)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)

if self.deepnorm:
self.encoder_normalize_before = False
self.subln = False
if self.subln:
self.encoder_normalize_before = True
self.deepnorm = False
if self.use_xmoe:
self.moe_normalize_gate_prob_before_dropping = True
self.moe_second_expert_policy = "random"
@@ -71,138 +49,58 @@ def override(self, args):
self.__dict__[hp] = getattr(args, hp, None)


class DecoderConfig(object):
class EncoderConfig(Config):
def __init__(self, **kwargs):
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
super().__init__(**kwargs)
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
self.encoder_layers = kwargs.pop("encoder_layers", 12)
self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
self.normalize_output = kwargs.pop("normalize_output", True)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
self.share_encoder_input_output_embed = kwargs.pop(
"share_encoder_input_output_embed", False
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
# Vision
self.img_size = kwargs.pop("img_size", 224)
self.patch_size = kwargs.pop("patch_size", 16)
self.in_chans = kwargs.pop("in_chans", 3)


if self.deepnorm:
self.decoder_normalize_before = False
self.encoder_normalize_before = False
self.subln = False
if self.subln:
self.decoder_normalize_before = True
self.encoder_normalize_before = True
self.deepnorm = False
if self.use_xmoe:
self.moe_normalize_gate_prob_before_dropping = True
self.moe_second_expert_policy = "random"
assert self.moe_freq > 0 and self.moe_expert_count > 0


def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)


class EncoderDecoderConfig(object):
class DecoderConfig(Config):
def __init__(self, **kwargs):
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
self.encoder_layers = kwargs.pop("encoder_layers", 12)
self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
super().__init__(**kwargs)
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)

if self.deepnorm:
self.encoder_normalize_before = False
self.decoder_normalize_before = False
self.subln = False
if self.subln:
self.encoder_normalize_before = True
self.decoder_normalize_before = True
self.deepnorm = False
if self.use_xmoe:
self.moe_normalize_gate_prob_before_dropping = True
self.moe_second_expert_policy = "random"
assert self.moe_freq > 0 and self.moe_expert_count > 0

def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)

class EncoderDecoderConfig(EncoderConfig, DecoderConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
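The net effect of this file's change is a class hierarchy instead of three copies of the same field list. A minimal, self-contained sketch of the pattern (field set heavily reduced, not the PR's exact code): shared hyper-parameters live in a base `Config`, each subclass chains to `super().__init__(**kwargs)` before popping its own keys, and `EncoderDecoderConfig` combines both sides through cooperative multiple inheritance.

```python
# Simplified illustration of the refactored config hierarchy.
class Config:
    def __init__(self, **kwargs):
        # Shared defaults; unrecognised keys are simply left in kwargs and ignored.
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)


class EncoderConfig(Config):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)


class DecoderConfig(Config):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)


class EncoderDecoderConfig(EncoderConfig, DecoderConfig):
    def __init__(self, **kwargs):
        # MRO: EncoderDecoderConfig -> EncoderConfig -> DecoderConfig -> Config,
        # so one constructor call initialises encoder, decoder and shared fields.
        super().__init__(**kwargs)
        self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)


cfg = EncoderDecoderConfig(encoder_embed_dim=512, dropout=0.1)
print(cfg.encoder_embed_dim, cfg.decoder_embed_dim, cfg.dropout)  # 512 768 0.1
```

Because every level consumes its keyword arguments with `kwargs.pop` and ignores the rest, encoder-only and decoder-only arguments can be passed through the same constructor, and `override` only needs to exist once on the base class.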

6 changes: 4 additions & 2 deletions torchscale/architecture/decoder.py
@@ -9,12 +9,14 @@
from fairscale.nn import checkpoint_wrapper, wrap

from torchscale.architecture.utils import init_bert_params
from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.relative_position_bias import RelativePositionBias
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate

try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
@@ -23,7 +25,7 @@
class DecoderLayer(nn.Module):
def __init__(
self,
args,
args: DecoderConfig,
depth,
is_moe_layer=False,
is_encoder_decoder=False,
@@ -209,7 +211,7 @@ def forward(
class Decoder(nn.Module):
def __init__(
self,
args,
args: DecoderConfig,
embed_tokens=None,
embed_positions=None,
output_projection=None,
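The `DecoderConfig` annotation documents the expected type of `args` without changing behaviour. A minimal usage sketch in the style of the repository README (illustrative values):

```python
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder

config = DecoderConfig(vocab_size=64000, decoder_layers=6)
decoder = Decoder(config)  # `args` is the annotated DecoderConfig
```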
11 changes: 8 additions & 3 deletions torchscale/architecture/encoder.py
@@ -13,6 +13,7 @@
from torch.nn import LayerNorm

from torchscale.architecture.utils import init_bert_params
from torchscale.architecture.config import EncoderConfig
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
@@ -23,7 +24,11 @@


class EncoderLayer(nn.Module):
def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
def __init__(self,
args: EncoderConfig,
depth,
is_moe_layer: bool = False,
is_encoder_decoder: bool = False):
super().__init__()
self.args = args
self.embed_dim = args.encoder_embed_dim
@@ -165,11 +170,11 @@ def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiwa
class Encoder(nn.Module):
def __init__(
self,
args,
args: EncoderConfig,
embed_tokens=None,
embed_positions=None,
output_projection=None,
is_encoder_decoder=False,
is_encoder_decoder: bool = False,
**kwargs
):
self.args = args
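With `args` annotated as `EncoderConfig`, a static checker can now catch config mix-ups at the call site. A hypothetical example of the kind of mistake this surfaces:

```python
from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.architecture.encoder import Encoder

ok = Encoder(EncoderConfig(vocab_size=64000))

# Hypothetical mistake: a type checker (e.g. mypy) now flags this argument as
# incompatible; without the annotation it would only fail at runtime when the
# encoder looks up encoder-specific fields on the wrong config object.
wrong = Encoder(DecoderConfig(vocab_size=64000))
```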
3 changes: 2 additions & 1 deletion torchscale/architecture/encoder_decoder.py
@@ -3,14 +3,15 @@

import torch.nn as nn

from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.decoder import Decoder
from torchscale.architecture.encoder import Encoder


class EncoderDecoder(nn.Module):
def __init__(
self,
args,
args: EncoderDecoderConfig,
encoder_embed_tokens=None,
encoder_embed_positions=None,
decoder_embed_tokens=None,
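The combined model follows the same pattern; a minimal sketch in the style of the repository README (illustrative values), assuming the merged `EncoderDecoderConfig` keeps the same constructor surface:

```python
from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder

config = EncoderDecoderConfig(vocab_size=32000, share_all_embeddings=True)
model = EncoderDecoder(config)  # one config object feeds both sub-modules
```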
2 changes: 1 addition & 1 deletion torchscale/component/feedforward_network.py
@@ -13,7 +13,7 @@
from .xmoe.global_groups import get_moe_group


class set_torch_seed(object):
class set_torch_seed:
def __init__(self, seed):
assert isinstance(seed, int)
self.rng_state = self.get_rng_state()
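As with `WeightIterator` above, only the class declaration style changes here. For context, a sketch of how such a seed guard is typically used, assuming `set_torch_seed` acts as a context manager that re-applies the RNG state captured in `__init__` on exit:

```python
import torch

from torchscale.component.feedforward_network import set_torch_seed

# Scope a fixed seed to a block; the surrounding RNG state is restored afterwards.
with set_torch_seed(1):
    w1 = torch.randn(4, 4)

with set_torch_seed(1):
    w2 = torch.randn(4, 4)

assert torch.equal(w1, w2)  # same seed, same draw
```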