This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9b3eb0f

rebase cleanup
drisspg committed Jun 28, 2024
1 parent 86cde7a commit 9b3eb0f
Showing 4 changed files with 8 additions and 23 deletions.
8 changes: 6 additions & 2 deletions  float8_experimental/float8_dynamic_linear.py

@@ -80,8 +80,12 @@ def create_meta_class(
         return cls(in_features=in_features, out_features=out_features, bias=False)
 
     def set_mm_configs(self, emulate: bool) -> "Float8DynamicLinear":
-        self.forward_config = ScaledMMConfig(emulate, not emulate, pad_inner_dim=config.pad_inner_dim)
-        self.backward_config = ScaledMMConfig(emulate, False, pad_inner_dim=config.pad_inner_dim)
+        self.forward_config = ScaledMMConfig(
+            emulate, not emulate, pad_inner_dim=config.pad_inner_dim
+        )
+        self.backward_config = ScaledMMConfig(
+            emulate, False, pad_inner_dim=config.pad_inner_dim
+        )
         return self
 
     def set_weight_and_bias(
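
For reference, the set_mm_configs change above is purely a line-wrapping reformat; behavior is unchanged. The second positional argument of ScaledMMConfig (which I read as a fast-accumulation flag, an assumption not confirmed by this diff) is enabled for the forward matmul only when not emulating, and is always left off for the backward matmul. A self-contained sketch of that pattern, using a stand-in NamedTuple rather than the library's real ScaledMMConfig (exact field names and defaults are assumed):

from typing import NamedTuple, Tuple

# Stand-in for float8_experimental's ScaledMMConfig; field names and defaults
# here are assumptions for illustration, not the library's actual definition.
class ScaledMMConfigSketch(NamedTuple):
    emulate: bool = False         # emulate the scaled matmul instead of using the fp8 kernel
    use_fast_accum: bool = False  # assumed meaning of the second positional argument above
    pad_inner_dim: bool = False   # pad the inner (K) dimension of the matmul

def make_mm_configs(emulate: bool, pad_inner_dim: bool) -> Tuple[ScaledMMConfigSketch, ScaledMMConfigSketch]:
    # Forward: the second flag is set only when not emulating (mirrors `not emulate` above).
    forward_config = ScaledMMConfigSketch(emulate, not emulate, pad_inner_dim=pad_inner_dim)
    # Backward: the flag stays False regardless of emulation.
    backward_config = ScaledMMConfigSketch(emulate, False, pad_inner_dim=pad_inner_dim)
    return forward_config, backward_config

print(make_mm_configs(emulate=False, pad_inner_dim=True))
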
20 changes: 0 additions & 20 deletions  float8_experimental/float8_ops.py

@@ -67,26 +67,6 @@ def make_float8(data):
     return list(out)
 
 
-@implements([aten.copy_.default])
-def copy_fp8(aten_op, args, kwargs=None):
-    self = args[0]
-    src = args[1]
-    assert isinstance(self, Float8Tensor) or isinstance(src, Float8Tensor)
-
-    self_data = self
-    if isinstance(self, Float8Tensor):
-        self_data = self._data
-
-    src_data = src
-    if isinstance(src, Float8Tensor):
-        src_data = src._data
-
-    fp8_out = aten_op(self_data, src_data, *args[2:], **kwargs)
-    if isinstance(self, Float8Tensor):
-        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
-    return fp8_out
-
-
 # Errors cant `cat_cuda float8 e4m3fn`
 @implements([aten.cat.default])
 def float8_cat(aten_op, args, kwargs=None):
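
For context, the deleted copy_fp8 function follows the pattern used throughout float8_ops.py: an @implements decorator registers a handler for one or more aten ops, and the handler unwraps any Float8Tensor arguments to their raw _data, calls the original op, and re-wraps the result. A minimal sketch of that registry pattern (simplified; the decorator body and the __torch_dispatch__ lookup that calls into it are assumptions about the surrounding code, not shown in this diff):

from typing import Any, Callable, Dict, List

# Simplified op-override registry in the spirit of float8_ops.py's @implements;
# not the library's actual implementation.
FLOAT8_OPS_TABLE: Dict[Any, Callable] = {}

def implements(aten_ops: List[Any]) -> Callable:
    """Register a handler for the given aten ops on the fp8 tensor subclass."""
    def decorator(func: Callable) -> Callable:
        for op in aten_ops:
            FLOAT8_OPS_TABLE[op] = func
        return func
    return decorator

# The tensor subclass's __torch_dispatch__ would then do roughly (assumed wiring):
#     if func in FLOAT8_OPS_TABLE:
#         return FLOAT8_OPS_TABLE[func](func, args, kwargs)
# which is how handlers such as the removed copy_fp8 and float8_cat below get invoked.
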
1 change: 1 addition & 0 deletions  float8_experimental/inference.py

@@ -6,6 +6,7 @@
 """
 Defines an nn module designed to be used during inference
 """
+
 from dataclasses import dataclass
 
 from enum import auto, Enum
2 changes: 1 addition & 1 deletion  test/test_base.py

@@ -127,7 +127,7 @@ def test_copy_(self):
         )
         fp8_b.copy_(fp8_a)
         torch.testing.assert_close(fp8_a._data, fp8_b._data)
-
+
     def test_weights_only_load(self):
         module = nn.Linear(16, 16)
         # Save model state dict
