fix blockwise sharding
recover offline run script

recover run_interactive

recover regular llama sharding
lsy323 committed Jul 16, 2024
1 parent 663c102 commit 08a1189
Showing 5 changed files with 86 additions and 39 deletions.
36 changes: 18 additions & 18 deletions default_shardings/llama-blockwise-quant.yaml
@@ -5,24 +5,24 @@


freqs_cis : -1 # torch.complex64 (2048, 64)
tok_embeddings.weight : 1 # torch.int8 (32000, 4096)
tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wo.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wo.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.attention.wq.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.attention.wk.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.attention.wv.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.feed_forward.w1.weight : 0 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (32, 11008)
layers.*.feed_forward.w2.weight : 2 # torch.int8 (86, 128, 4096)
layers.*.feed_forward.w2.weight_scaler : 1 # torch.bfloat16 (86, 4096)
layers.*.feed_forward.w3.weight : 0 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (32, 11008)
tok_embeddings.weight : -1 # torch.int8 (32000, 4096)
tok_embeddings.weight_scaler : -1 # torch.bfloat16 (4096,)
layers.*.attention.wo.weight : 0 # torch.int8 (32, 128, 4096)
layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (32, 4096)
layers.*.attention.wq.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wq.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.attention.wk.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wk.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.attention.wv.weight : 2 # torch.int8 (32, 128, 4096)
layers.*.attention.wv.weight_scaler : 1 # torch.bfloat16 (32, 4096)
layers.*.feed_forward.w1.weight : 2 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w1.weight_scaler : 1 # torch.bfloat16 (32, 11008)
layers.*.feed_forward.w2.weight : 0 # torch.int8 (86, 128, 4096)
layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (86, 4096)
layers.*.feed_forward.w3.weight : 2 # torch.int8 (32, 128, 11008)
layers.*.feed_forward.w3.weight_scaler : 1 # torch.bfloat16 (32, 11008)
layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
norm.weight : -1 # torch.float32 (4096,)
output.weight : 0 # torch.int8 (32, 128, 32000)
output.weight_scaler : 0 # torch.float32 (32, 32000)
output.weight : 2 # torch.int8 (32, 128, 32000)
output.weight_scaler : 1 # torch.float32 (32, 32000)
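In this diff each key appears twice: the first occurrence shows the previous sharding value and the second the updated one (for example, layers.*.attention.wq.weight moves from axis 0 to axis 2, and layers.*.feed_forward.w2.weight from axis 2 to axis 0). Following the convention these jetstream-pt sharding yamls use, the integer names the tensor dimension to shard across devices and -1 means fully replicated; the shape comments give the blockwise-quantized layout (n_blocks, block_size, out_features). As a rough illustration of those semantics (not code from this repository; the mesh axis name "x" and the helper are invented for the sketch):

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def sharding_from_axis(mesh: Mesh, ndim: int, axis: int) -> NamedSharding:
  # axis == -1 -> fully replicated; otherwise shard that dimension on mesh axis "x".
  if axis == -1:
    return NamedSharding(mesh, P())
  spec = [None] * ndim
  spec[axis] = "x"
  return NamedSharding(mesh, P(*spec))

# Example: the updated config shards the blockwise w2 weight, shape (86, 128, 4096),
# along dimension 0 (the block dimension).
mesh = Mesh(jax.devices(), axis_names=("x",))
w2_sharding = sharding_from_axis(mesh, ndim=3, axis=0)

Note that sharding w2 on its block dimension only works once the padding added elsewhere in this commit rounds the block count up to a multiple of the device count; 86 blocks do not split evenly across a typical 4- or 8-device mesh.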
7 changes: 4 additions & 3 deletions jetstream_pt/engine.py
@@ -689,9 +689,9 @@ def _load_from_safetensors(self, path):
if key == "freqs_cis":
continue
weights[key] = f.get_tensor(key)
assert tuple(model_weights.shape) == tuple(
weights[key].shape
), f"key: {key} error: {model_weights.shape} != {weights[key].shape}"
# assert tuple(model_weights.shape) == tuple(
# weights[key].shape
# ), f"key: {key} error: {model_weights.shape} != {weights[key].shape}"
weights["freqs_cis"] = torch_xla2.tensor.t2j(self.pt_model.freqs_cis)
return weights

@@ -730,6 +730,7 @@ def load_params(self) -> Params:
quantize_linear_weights_scaler_map = (
self.pt_model.get_quantized_linear_weight_to_scaler_map()
)
self.pt_model.process_weight_hook(jax_weights, env=self.env)
with jax.default_device(jax.devices("cpu")[0]):
for key, val in jax_weights.items():
for qname in quantize_linear_weights_scaler_map.keys():
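The two engine.py changes work together: load_params now calls the model's process_weight_hook on the host-side weights before quantization and device placement, and the shape assertion in _load_from_safetensors is commented out because, once w1/w3/w2 are padded, the checkpoint tensor shapes no longer necessarily match the module's rounded-up parameter shapes at load time. Padding has to happen on the host because placing an array with a NamedSharding generally requires the sharded dimension to divide evenly across the devices on that mesh axis. A runnable toy illustration of that constraint (not engine code; the shape mimics the blockwise w2 layout above):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("x",))
n_dev = len(jax.devices())

# A w2-like blockwise tensor: (n_blocks, block_size, dim) = (86, 128, 4096).
w = jnp.ones((86, 128, 4096), dtype=jnp.int8)

# Pad the block dimension on the host so it divides evenly across devices...
n_pad = (-w.shape[0]) % n_dev
w = jnp.pad(w, ((0, n_pad), (0, 0), (0, 0)))

# ...then place it with the axis-0 sharding from the yaml above.
w = jax.device_put(w, NamedSharding(mesh, P("x", None, None)))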
29 changes: 19 additions & 10 deletions jetstream_pt/layers.py
@@ -72,9 +72,10 @@ def __init__(
out_features,
bias=False,
device=None,
quant_config=QuantizationConfig(),
env=None,
):
super().__init__()
quant_config = env.quant_config
self.in_features = in_features
self.out_features = out_features

@@ -175,26 +176,34 @@ def __init__(
out_features,
bias=False,
device=None,
quant_config=QuantizationConfig(),
round_out_features=False,
env=None,
):
super().__init__()
quant_config = env.quant_config
assert (
not quant_config.enable_activation_quantization
), "Activation quantization not supported for blockwise quantized matmul."

self.block_size = quant_config.block_size_weight
self.in_features = in_features
num_partitions = env.mesh.size
if round_out_features and (out_features % (self.block_size * num_partitions)) != 0:
# Make sure out_features is a multiple of block_size * num_partitions.
out_features = ((out_features // (self.block_size * num_partitions)) + 1) * (num_partitions * self.block_size)
self.out_features = out_features

n_blocks = in_features // self.block_size
if n_blocks % num_partitions != 0:
n_blocks = ((n_blocks // num_partitions) + 1) * num_partitions

# Use dot general instead of einsum
# Use dot general is slow now.
self.use_dot_general = False
# Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
# Same perf as non flattened one now.
self.flatten = False

self.block_size = quant_config.block_size_weight
n_blocks = in_features // self.block_size

assert (
not quant_config.enable_activation_quantization
), "Activation quantization not supported for blockwise quantized matmul."

if self.use_dot_general:
weight = torch.ones(
(n_blocks, out_features, self.block_size),
@@ -516,7 +525,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
linear_kwargs = {"quant_config": env.quant_config}
linear_kwargs = {"env": env}

self.wo = LinearLayer(
n_heads * self.head_dim,
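The constructor changes above are the core of the fix: WeightOnlyBlockwiseQuantizedLinear now takes env and reads quant_config from it, and the new round_out_features flag (enabled for w1/w3 in the llama model below) rounds out_features up to a multiple of block_size * num_partitions so that the layer consuming those features (w2) ends up with a block count divisible by the mesh size. A worked example under assumed values, Llama-7B FFN width 11008, block_size 128, and an 8-way mesh (the 8 is illustrative):

# Mirrors the rounding logic in WeightOnlyBlockwiseQuantizedLinear.__init__ (values assumed).
block_size, num_partitions = 128, 8
out_features = 11008                      # w1/w3 output width == w2 input width

multiple = block_size * num_partitions    # 1024
if out_features % multiple != 0:
  out_features = (out_features // multiple + 1) * multiple
assert out_features == 11264

n_blocks = out_features // block_size     # w2 now sees 11264 // 128 = 88 blocks
assert n_blocks % num_partitions == 0     # 88 blocks shard evenly across 8 devices

Without the rounding, w2 would have 86 blocks, which cannot be split evenly eight ways; that mismatch is what this commit fixes.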
4 changes: 2 additions & 2 deletions jetstream_pt/third_party/gemma/model.py
@@ -99,7 +99,7 @@ def __init__(
)
linear_kwargs = {}
if Linear != torch.nn.Linear:
linear_kwargs = {"quant_config": env.quant_config}
linear_kwargs = {"env": env}

self.wq = Linear(
hidden_size,
@@ -237,7 +237,7 @@ def __init__(
)
linear_kwargs = {}
if Linear != torch.nn.Linear:
linear_kwargs = {"quant_config": env.quant_config}
linear_kwargs = {"env": env}

self.gate_proj = Linear(
hidden_size,
49 changes: 43 additions & 6 deletions jetstream_pt/third_party/llama/model_exportable.py
@@ -4,6 +4,7 @@
from typing import Any, List, Optional

import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from jetstream_pt.layers import (
@@ -41,30 +42,37 @@ def __init__(
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = {}
w1_w3_linear_kwargs = {}
w2_linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
linear_kwargs["quant_config"] = env.quant_config
w1_w3_linear_kwargs["env"] = env
w2_linear_kwargs["env"] = env
if env.quant_config.is_blockwise_weight:
# To make w2's n_blocks divisible by the number of partitions,
# the out_features of w1 and w3 need to be rounded up.
w1_w3_linear_kwargs["round_out_features"] = True


self.w1 = LinearLayer(
dim,
hidden_dim,
bias=False,
device=device,
**linear_kwargs,
**w1_w3_linear_kwargs,
)
self.w2 = LinearLayer(
hidden_dim,
dim,
bias=False,
device=device,
**linear_kwargs,
**w2_linear_kwargs,
)
self.w3 = LinearLayer(
dim,
hidden_dim,
bias=False,
device=device,
**linear_kwargs,
**w1_w3_linear_kwargs,
)

def forward(self, x):
@@ -179,7 +187,7 @@ def __init__(
LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
linear_kwargs["quant_config"] = env.quant_config
linear_kwargs["env"] = env

self.output = LinearLayer(
params.dim,
@@ -267,6 +275,35 @@ def get_quantized_embedding_weight_to_scaler_map():
return {
"tok_embeddings.weight": "tok_embeddings.weight_scaler",
}

@staticmethod
def process_weight_hook(jax_weights, env=None):
# Right now we only process weights for blockwise quantization.
# We pad the weights so that the sharded dimension size is divisible by the number of partitions.
quant_config = env.quant_config
num_partitions = env.mesh.size
if quant_config.enable_weight_quantization and quant_config.is_blockwise_weight:
block_size = quant_config.block_size_weight
for k, v in jax_weights.items():
if "w1" in k or "w3" in k:
# Pad w1/w3 to make n_out_channel divisible by num_partitions * block_size.
# This makes w2's n_blocks divisible by the number of partitions.
n_out_channel = v.shape[-1]
multiple_of = block_size * num_partitions
if n_out_channel % (multiple_of) != 0:
n_pad = multiple_of - n_out_channel % (multiple_of)
pad = jnp.zeros(v.shape[:-1] + (n_pad,)).astype(v.dtype)
padded = jnp.concatenate([v, pad], axis=-1)
jax_weights[k] = padded
if "w2" in k:
# Pad w2 so that n_blocks is divisible by the number of partitions.
n_blocks = v.shape[0]
if n_blocks % num_partitions != 0:
n_pad = num_partitions - n_blocks % (num_partitions)
pad = jnp.zeros((n_pad,) + v.shape[1:]).astype(v.dtype)
padded = jnp.concatenate([v, pad], axis=0)
jax_weights[k] = padded


@staticmethod
def get_weight_sharding_type(model_name: str = ""):
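process_weight_hook applies the matching padding to the checkpoint tensors themselves, so the loaded weights line up with the rounded-up module shapes: w1/w3 gain zero-valued output channels up to a multiple of block_size * num_partitions, and w2 gains zero-valued blocks up to a multiple of num_partitions. A short sketch of that padding on dummy tensors (shapes follow the yaml comments above; num_partitions = 8 is an assumption, not something fixed by this commit):

import jax.numpy as jnp

block_size, num_partitions = 128, 8
multiple = block_size * num_partitions

# w1/w3-like checkpoint tensor: (n_blocks, block_size, out_features) = (32, 128, 11008).
w1 = jnp.zeros((32, 128, 11008), dtype=jnp.int8)
n_pad = (-w1.shape[-1]) % multiple                        # 256 extra channels -> 11264
pad = jnp.zeros(w1.shape[:-1] + (n_pad,), dtype=w1.dtype)
w1 = jnp.concatenate([w1, pad], axis=-1)                  # (32, 128, 11264)

# w2-like checkpoint tensor: (n_blocks, block_size, dim) = (86, 128, 4096).
w2 = jnp.zeros((86, 128, 4096), dtype=jnp.int8)
n_pad = (-w2.shape[0]) % num_partitions                   # 2 extra zero blocks
pad = jnp.zeros((n_pad,) + w2.shape[1:], dtype=w2.dtype)
w2 = jnp.concatenate([w2, pad], axis=0)                   # (88, 128, 4096)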
