This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit d4cf2ad

make dynamic scaling default in Float8Linear (#300)
Summary:
Pull Request resolved: #300

1. Makes dynamic scaling the default in Float8Linear, for an easier migration
   of call sites that currently use Float8DynamicLinear. Fixes tests as needed.
2. Updates the README to reference Float8Linear for dynamic scaling.

Reviewed By: drisspg

Differential Revision: D59305790

fbshipit-source-id: 30d3813946239e0e958e0f7ed446082b578b0607
vkuzo authored and facebook-github-bot committed Jul 3, 2024
1 parent 4fb0ada commit d4cf2ad
Showing 5 changed files with 54 additions and 19 deletions.
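To make change (1) in the summary concrete: after this commit, a call site that previously swapped to `Float8DynamicLinear` can pass `Float8Linear` and get dynamic scaling with no extra arguments. A minimal sketch — the toy `nn.Sequential` model is a placeholder, not from the diff; the swap call mirrors the updated README below:

```python
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Placeholder model; any module containing nn.Linear layers works.
m = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# Before: swap_linear_with_float8_linear(m, Float8DynamicLinear)
# After: Float8Linear now defaults to dynamic scaling for x, w and dL_dY,
# so no scaling_type_* arguments are needed for the dynamic recipe.
swap_linear_with_float8_linear(m, Float8Linear)
```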
31 changes: 20 additions & 11 deletions README.md

@@ -27,21 +27,23 @@ pip install -e ".[dev]"
 
 # User API
 
-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
 
-## float8 linear with dynamic scaling
+## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
 
+This is the most accurate recipe as every tensor is scaled dynamically.
+
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`
-swap_linear_with_float8_linear(m, Float8DynamicLinear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`
+swap_linear_with_float8_linear(m, Float8Linear)
 
 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
@@ -54,18 +56,27 @@ m = torch.compile(m)
 
 ## float8 linear with delayed scaling
 
+This is theoretically the most performant recipe as it minimizes memory reads.
+
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m, Float8Linear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
+# type
+swap_linear_with_float8_linear(
+    m,
+    Float8Linear,
+    scaling_type_x=TensorScalingType.DELAYED,
+    scaling_type_w=TensorScalingType.DELAYED,
+    scaling_type_dL_dY=TensorScalingType.DELAYED,
+)
 
 # optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
@@ -93,9 +104,7 @@ for _ in range(N_ITER):
 # 🧭 Code Organization
 
 * `float8_experimental/float8_linear.py`
-  - `Float8Linear` (main user facing entry point for delayed scaling)
-* `float8_experimental/float8_dynamic_linear.py`
-  - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
+  - `Float8Linear` (main user facing entry point for Float8Linear)
 * `float8_experimental/float8_tensor.py`
   - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
   - `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass
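The delayed-scaling example in the README hunk above is truncated by the collapsed region (the training loop around `for _ in range(N_ITER):` is not shown). As a rough, self-contained sketch of how the per-iteration sync is used — the toy model, optimizer, input, iteration count, and `emulate=True` are illustrative assumptions, not part of the diff:

```python
import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# Toy stand-ins for the README's Model(...); emulate=True performs the float8
# matmul in higher precision so the sketch does not require float8-capable hardware.
m = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
swap_linear_with_float8_linear(
    m,
    Float8Linear,
    emulate=True,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
x = torch.randn(16, 16)
N_ITER = 10

for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    # Delayed scaling needs the amax/scale history synced every iteration;
    # with the new dynamic-scaling default this call is not needed.
    sync_float8_amax_and_scale_history(m)
    optimizer.step()
```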
6 changes: 3 additions & 3 deletions float8_experimental/float8_linear_utils.py

@@ -191,9 +191,9 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
 ) -> Optional[nn.Module]:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
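Since the scaling types are independent keyword arguments, the new `DYNAMIC` defaults can also be overridden per tensor. A small sketch (toy model and a hypothetical choice of settings, not from the diff) that keeps the default dynamic scaling for `x` and `dL_dY` while opting only the weight into delayed scaling, mirroring the `scaling_type_w` parametrization in `test/test_fsdp.py` below:

```python
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# x and dL_dY fall back to the new DYNAMIC default; only w uses delayed scaling.
swap_linear_with_float8_linear(
    m,
    Float8Linear,
    scaling_type_w=TensorScalingType.DELAYED,
)
```

As the `test/test_fsdp.py` change below shows, `sync_float8_amax_and_scale_history` is still required whenever any tensor uses delayed scaling; `linear_requires_sync` performs that check.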
16 changes: 14 additions & 2 deletions test/test_compile.py

@@ -299,7 +299,13 @@ def test_sync_amax_func():
     module = torch.nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     )
-    float8_mod = swap_linear_with_float8_linear(module, Float8Linear)
+    float8_mod = swap_linear_with_float8_linear(
+        module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
     compiled_swap_func(float8_mod)
     assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
@@ -329,7 +335,13 @@ def test_sync_amax_func_cuda_graph_success():
     my_module = nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     ).to("cuda")
-    swap_linear_with_float8_linear(my_module, Float8Linear)
+    swap_linear_with_float8_linear(
+        my_module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     inpt = torch.randn(
         16, 16, device="cuda", dtype=torch.float32, requires_grad=True
     )
9 changes: 8 additions & 1 deletion test/test_fsdp.py

@@ -23,6 +23,8 @@
 import torch.nn as nn
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    LinearType,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -130,7 +132,12 @@ def forward_backward(model, optim, is_fp8, i):
     optim.zero_grad()
     y_local = model(ref_input_local[i])
     y_local.backward(ref_grad_local[i])
-    if is_fp8:
+    if is_fp8 and linear_requires_sync(
+        LinearType.DELAYED,
+        TensorScalingType.DYNAMIC,
+        scaling_type_w,
+        TensorScalingType.DYNAMIC,
+    ):
         sync_float8_func(model)
     optim.step()
     return y_local
11 changes: 9 additions & 2 deletions test/test_fsdp_compile.py

@@ -18,7 +18,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -49,7 +49,14 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
         nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
     )
-    swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+    swap_linear_with_float8_linear(
+        m,
+        Float8Linear,
+        emulate=emulate,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     return m
