forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfloat8_ops.py
364 lines (314 loc) · 11.3 KB
/
float8_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Tuple
import torch
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_tensor import choose_scaled_mm_config, Float8Tensor
from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
from torch.utils._pytree import tree_map
# Shorthand handles to the torch op namespaces whose overloads get handlers
# registered in this module.
aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
# Maps an op overload to the handler function registered for it through the
# `implements` decorator.
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
def implements(aten_ops):
    """Build a decorator that registers a handler in ``FLOAT8_OPS_TABLE``.

    Args:
        aten_ops: iterable of op overloads the decorated function handles.

    Returns:
        A decorator that stores the decorated function under every entry of
        ``aten_ops`` and returns the function itself, unwrapped.
    """

    def register(handler):
        # A later registration for the same op overwrites the earlier one.
        for overload in aten_ops:
            FLOAT8_OPS_TABLE[overload] = handler
        return handler

    return register
@implements(
    [
        aten.view.default,
        aten._unsafe_view.default,
        aten.t.default,
        aten.as_strided.default,
        aten.clone.default,
        aten.detach.default,
        aten.slice.Tensor,
        aten.transpose.int,
        aten.fill_.Scalar,
        aten.reshape.default,
    ]
)
def float8_desugar_op(aten_op, args, kwargs=None):
    """Run ``aten_op`` on the raw data of a Float8Tensor and rewrap the result.

    Args:
        aten_op: the op overload being dispatched.
        args: positional op arguments; ``args[0]`` is the Float8Tensor whose
            ``_data`` is substituted as the first argument of the op.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        A Float8Tensor wrapping the op output, carrying the input's scale,
        original dtype, linear mm config and gemm input role unchanged.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs
    fp8_input = args[0]
    # NOTE(review): the remaining args are forwarded verbatim, so for
    # aten.fill_.Scalar the scalar is written into the raw data without the
    # scale being applied to it -- confirm this is intended.
    new_data = aten_op(fp8_input._data, *args[1:], **kwargs)
    return Float8Tensor(
        new_data,
        fp8_input._scale,
        fp8_input._orig_dtype,
        fp8_input._linear_mm_config,
        fp8_input._gemm_input_role,
    )
@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
    """Split the raw data of a Float8Tensor and wrap every piece.

    Args:
        aten_op: the op overload being dispatched.
        args: positional op arguments; ``args[0]`` is the Float8Tensor to
            split, the rest are forwarded to ``aten_op``.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        A list of Float8Tensors, one per piece returned by ``aten_op``. All
        of them reuse the input's scale object, original dtype, linear mm
        config and gemm input role.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs
    fp8_input = args[0]
    pieces = aten_op(fp8_input._data, *args[1:], **kwargs)
    return [
        Float8Tensor(
            piece,
            fp8_input._scale,
            fp8_input._orig_dtype,
            fp8_input._linear_mm_config,
            fp8_input._gemm_input_role,
        )
        for piece in pieces
    ]
@implements([aten.cat.default])
def float8_cat(aten_op, args, kwargs=None):
    """Concatenate Float8Tensors that share scale, config and dtypes.

    Args:
        aten_op: the op overload being dispatched.
        args: positional op arguments; ``args[0]`` is the sequence of
            Float8Tensors to concatenate, the rest are forwarded to
            ``aten_op``.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        A Float8Tensor wrapping the concatenated data, with the metadata of
        the first chunk.

    Raises:
        AssertionError: if a chunk is not a Float8Tensor, or its original
            dtype, scale object, mm config object, fp8 dtype or gemm input
            role differs from the first chunk's.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs
    chunked_tensors: Tuple[Float8Tensor] = args[0]

    # The first chunk is the reference every other chunk is compared against.
    orig_dtype = chunked_tensors[0]._orig_dtype
    scale = chunked_tensors[0]._scale
    mm_config = chunked_tensors[0]._linear_mm_config
    fp8_dtype = chunked_tensors[0]._data.dtype
    gemm_input_role = chunked_tensors[0]._gemm_input_role
    chunk_data = []
    for chunk in chunked_tensors:
        assert isinstance(
            chunk, Float8Tensor
        ), "Expecting all chunks to be of type Float8Tensor"
        assert (
            chunk._orig_dtype == orig_dtype
        ), "Expecting all chunks to be of the same dtype"
        # Identity, not value equality: chunks are expected to come from a
        # split of one tensor and therefore share the same scale object.
        assert (
            chunk._scale is scale
        ), "Expecting all chunks to have thee same scale as a result of a split"
        assert (
            chunk._linear_mm_config is mm_config
        ), "Expecting all chunks to have thee same mm config as a result of a split"
        assert (
            chunk._data.dtype == fp8_dtype
        ), "Expecting all chunks to be of the same dtype as a result of a split"
        assert (
            chunk._gemm_input_role is gemm_input_role
        ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
        # Per the original author's note, cat on CUDA errors for float8
        # e4m3fn, so the data is reinterpreted as uint8 for the concatenation
        # and viewed back as the fp8 dtype afterwards.
        chunk_data.append(chunk._data.view(torch.uint8))

    new_data = aten_op(chunk_data, *args[1:], **kwargs)
    new_data = new_data.view(fp8_dtype)
    return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role)
@implements([aten.sum.dim_IntList])
def float8_cast_up_op(aten_op, args, kwargs=None):
    """Be careful with this function, this is a "fallback" op that
    casts every Float8Tensor input back to its original precision and then
    performs the op on the high precision values.

    We currently need this to support the backward for addmm bias.
    "addmm" -> out
    "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"

    Args:
        aten_op: the op overload being dispatched.
        args: positional op arguments, possibly containing Float8Tensors.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        Whatever ``aten_op`` returns for the up-cast arguments.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs

    def unwrap(x):
        # Only Float8Tensor leaves are converted; everything else passes
        # through untouched.
        if isinstance(x, Float8Tensor):
            return x.to_original_precision()
        return x

    new_args = tree_map(unwrap, args)
    new_kwargs = tree_map(unwrap, kwargs)
    return aten_op(*new_args, **new_kwargs)
def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
    """Prepare the raw data of two Float8Tensors for a scaled matmul.

    When the selected scaled-mm config has ``pad_inner_dim`` set, the shared
    inner dimension (dim 1 of ``a``, dim 0 of ``b``) is passed through
    ``pad_tensor_for_matmul``. Afterwards ``a``'s data is made contiguous
    unless its strides are already row-major, and ``b``'s data is re-laid-out
    so that its transpose is contiguous whenever its strides are row-major.

    Args:
        a: left operand.
        b: right operand.

    Returns:
        The tuple ``(a_data, a_scale, b_data, b_scale)``.

    Raises:
        AssertionError: if padding is requested and the inner dims differ.
    """
    lhs, lhs_scale = a._data, a._scale
    rhs = b._data

    mm_config = choose_scaled_mm_config(
        a._gemm_input_role,
        a._linear_mm_config,
        b._gemm_input_role,
        b._linear_mm_config,
    )

    if mm_config.pad_inner_dim:
        assert a._data.size(1) == b._data.size(
            0
        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
        lhs = pad_tensor_for_matmul(lhs, dims=1)
        rhs = pad_tensor_for_matmul(rhs, dims=0)

    # Left operand: force a row-major layout.
    lhs = lhs if is_row_major(lhs.stride()) else lhs.contiguous()
    # Right operand: a row-major layout is converted to one whose transpose
    # is contiguous -- presumably the layout the scaled-mm kernel expects.
    if is_row_major(rhs.stride()):
        rhs = rhs.t().contiguous().t()

    return lhs, lhs_scale, rhs, b._scale
@implements([aten.mm.default, aten.matmul.default])
def float8_mm(aten_op, args, kwargs=None):
    """Multiply two Float8Tensors, requesting ``a``'s original dtype as output.

    Args:
        aten_op: the op overload being dispatched (not called directly).
        args: positional op arguments; ``args[0]`` and ``args[1]`` are the
            Float8Tensor operands.
        kwargs: unused.

    Returns:
        The result of ``mm_float8_emulated`` when the selected scaled-mm
        config has ``emulate`` set, otherwise the result of
        ``addmm_float8_unwrapped`` without bias.

    Raises:
        AssertionError: if either operand is not a Float8Tensor.
    """
    a = args[0]
    b = args[1]

    assert isinstance(a, Float8Tensor) and isinstance(
        b, Float8Tensor
    ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
        type(a), type(b)
    )
    output_dtype = a._orig_dtype
    scaled_mm_config = choose_scaled_mm_config(
        a._gemm_input_role,
        a._linear_mm_config,
        b._gemm_input_role,
        b._linear_mm_config,
    )
    if scaled_mm_config.emulate:
        # The emulated path consumes the raw, unprocessed data, so the
        # padding / layout work of preprocess_addmm is skipped entirely.
        return torch.ops.aten.mm_float8_emulated(
            a._data, a._scale, b._data, b._scale, output_dtype
        )

    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
    return addmm_float8_unwrapped(
        a_data,
        a_scale,
        b_data,
        b_scale,
        output_dtype,
        output_scale=None,
        bias=None,
        use_fast_accum=scaled_mm_config.use_fast_accum,
    )
@implements([aten.addmm.default])
def float8_addmm(aten_op, args, kwargs=None):
    """Compute ``bias + a @ b`` for Float8Tensor operands ``a`` and ``b``.

    Args:
        aten_op: the op overload being dispatched (not called directly).
        args: positional op arguments; ``args[0]`` is the bias tensor,
            ``args[1]`` and ``args[2]`` are the Float8Tensor operands.
        kwargs: unused.

    Returns:
        ``mm_float8_emulated(...) + bias`` when the selected scaled-mm config
        has ``emulate`` set, otherwise the result of
        ``addmm_float8_unwrapped`` with the bias passed through.

    Raises:
        AssertionError: if the argument types are wrong or the bias dtype
            differs from ``a``'s original dtype.
    """
    assert (
        isinstance(args[0], torch.Tensor)
        and isinstance(args[1], Float8Tensor)
        and isinstance(args[2], Float8Tensor)
    ), "Expecting a Tensor bias and two Float8Tensor operands for addmm"
    bias = args[0]
    a = args[1]
    b = args[2]

    output_dtype = a._orig_dtype
    assert bias.dtype == output_dtype, "bias dtype must match output dtype"
    scaled_mm_config = choose_scaled_mm_config(
        a._gemm_input_role,
        a._linear_mm_config,
        b._gemm_input_role,
        b._linear_mm_config,
    )
    if scaled_mm_config.emulate:
        # The emulated path consumes the raw, unprocessed data, so the
        # padding / layout work of preprocess_addmm is skipped entirely.
        out = torch.ops.aten.mm_float8_emulated(
            a._data, a._scale, b._data, b._scale, output_dtype
        )
        return out + bias

    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
    return addmm_float8_unwrapped(
        a_data,
        a_scale,
        b_data,
        b_scale,
        output_dtype,
        output_scale=None,
        bias=bias,
        use_fast_accum=scaled_mm_config.use_fast_accum,
    )
@implements([aten.is_same_size.default])
def float8_is_same_size(aten_op, args, kwargs=None):
    """Report whether the first two op arguments have equal shapes.

    ``aten_op`` and ``kwargs`` are accepted for handler-signature
    compatibility and are not used.
    """
    lhs, rhs = args[0], args[1]
    return lhs.shape == rhs.shape
@implements([aten._to_copy.default])
def autocast_to_copy(aten_op, args, kwargs=None):
    """This gets called when running matmul under autocast
    when the input is a Float8Tensor, presenting as a fp32
    tensor.

    The raw data and scale are left untouched; only the recorded original
    dtype is replaced by the requested dtype.

    Args:
        aten_op: the op overload being dispatched (not called directly).
        args: positional op arguments; ``args[0]`` is the Float8Tensor.
        kwargs: must contain exactly one entry, ``dtype``, equal to
            ``torch.float16`` or ``torch.bfloat16``.

    Returns:
        A Float8Tensor sharing the input's data, scale, linear mm config and
        gemm input role, with ``kwargs["dtype"]`` as its original dtype.

    Raises:
        AssertionError: if the input is not a Float8Tensor or ``kwargs`` does
            not match the contract above (including when it is omitted).
    """
    # Treat a missing kwargs as empty so the assertion below reports the
    # problem instead of len(None) raising TypeError.
    kwargs = {} if kwargs is None else kwargs
    assert isinstance(args[0], Float8Tensor)
    assert (
        len(kwargs) == 1 and "dtype" in kwargs
    ), "Only support dtype kwarg for autocast"
    assert kwargs["dtype"] in {
        torch.float16,
        torch.bfloat16,
    }, "Only support floating point conversion for autocast w/ Float8Tensor"
    return Float8Tensor(
        args[0]._data,
        args[0]._scale,
        kwargs["dtype"],
        args[0]._linear_mm_config,
        args[0]._gemm_input_role,
    )
@implements(
    [
        c10d_functional.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor.default,
    ]
)
def allgather_fp8(aten_op, args, kwargs=None):
    """
    override funcol with FP8 handling

    The collective is run on a contiguous copy/view of the raw data and the
    output is rewrapped with the input's own scale and metadata.

    Args:
        aten_op: the all-gather op overload being dispatched.
        args: positional op arguments; ``args[0]`` is the Float8Tensor to
            gather, the rest are forwarded to ``aten_op``.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        A Float8Tensor wrapping the gathered data.

    Raises:
        AssertionError: if ``args[0]`` is not a Float8Tensor.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs
    fp8_input = args[0]
    assert isinstance(
        fp8_input, Float8Tensor
    ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"

    fp8_data = fp8_input._data.contiguous()
    fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
    # NOTE(review): the local scale is reused for the gathered result, which
    # presumably relies on every rank holding the same scale -- confirm.
    return Float8Tensor(
        fp8_out,
        fp8_input._scale,
        fp8_input._orig_dtype,
        fp8_input._linear_mm_config,
        fp8_input._gemm_input_role,
    )
@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
def wait_tensor_fp8(aten_op, args, kwargs=None):
    """Run ``wait_tensor`` on the raw data of a Float8Tensor and rewrap it.

    Args:
        aten_op: the wait_tensor op overload being dispatched.
        args: positional op arguments; ``args[0]`` is the Float8Tensor, the
            rest are forwarded to ``aten_op``.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        A Float8Tensor wrapping the op output with the input's scale,
        original dtype, linear mm config and gemm input role.

    Raises:
        AssertionError: if ``args[0]`` is not a Float8Tensor.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs
    fp8_input = args[0]
    assert isinstance(
        fp8_input, Float8Tensor
    ), f"expecting a Float8Tensor for wait_tensor but found {type(fp8_input)}"

    fp8_out = aten_op(fp8_input._data, *args[1:], **kwargs)
    return Float8Tensor(
        fp8_out,
        fp8_input._scale,
        fp8_input._orig_dtype,
        fp8_input._linear_mm_config,
        fp8_input._gemm_input_role,
    )
@implements([aten.index_put_.default])
def index_put_fp8(aten_op, args, kwargs=None):
    """Run ``index_put_`` between the raw data of two Float8Tensors.

    Args:
        aten_op: the op overload being dispatched.
        args: positional op arguments; ``args[0]`` is the destination
            Float8Tensor, ``args[1]`` the indices, ``args[2]`` the
            Float8Tensor of values; anything after is forwarded.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        A Float8Tensor wrapping the op output with the destination's scale
        and metadata.

    Raises:
        AssertionError: if either tensor is not a Float8Tensor, or their
            scale, dtype or original dtype differ.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs
    fp8_self = args[0]
    fp8_values = args[2]
    assert isinstance(fp8_self, Float8Tensor)
    assert isinstance(fp8_values, Float8Tensor)
    # Raw values are only interchangeable when both sides use the same scale.
    # NOTE(review): this asserts on the result of `==` between the two scale
    # objects; if scales can hold more than one element the truth value is
    # ambiguous -- confirm scales are single-element here.
    assert fp8_self._scale == fp8_values._scale
    assert fp8_self.dtype == fp8_values.dtype
    assert fp8_self._orig_dtype == fp8_values._orig_dtype

    fp8_out = aten_op(
        fp8_self._data, args[1], fp8_values._data, *args[3:], **kwargs
    )
    return Float8Tensor(
        fp8_out,
        fp8_self._scale,
        fp8_self._orig_dtype,
        fp8_self._linear_mm_config,
        fp8_self._gemm_input_role,
    )
@implements([aten.copy_.default])
def copy_fp8(aten_op, args, kwargs=None):
    """Handle ``copy_`` when at least the source is a Float8Tensor.

    Only the following combinations are allowed:

    1. The destination is a high precision (hp) tensor and the source is a
       Float8Tensor: the source is converted with ``to_original_precision``
       and copied into the hp tensor.
    2. Destination and source are both Float8Tensors: the raw data is copied,
       and only if all Float8Tensor properties are equal (a la torch.cat).

    Every other combination is banned as the semantics are not well defined.

    Args:
        aten_op: the op overload being dispatched.
        args: positional op arguments; ``args[0]`` is the destination,
            ``args[1]`` the source, the rest are forwarded to ``aten_op``.
        kwargs: keyword op arguments; ``None`` is treated as no kwargs.

    Returns:
        Case 1: whatever ``aten_op`` returns for the hp copy.
        Case 2: a Float8Tensor wrapping the op output with the destination's
        scale and metadata.

    Raises:
        AssertionError: in case 2, if any Float8Tensor property differs.
        RuntimeError: for any unsupported destination/source combination.
    """
    # The signature advertises kwargs as optional, so do not splat None.
    kwargs = {} if kwargs is None else kwargs
    dst = args[0]
    src = args[1]
    if not isinstance(dst, Float8Tensor) and isinstance(src, Float8Tensor):
        src_hp = src.to_original_precision()
        return aten_op(dst, src_hp, *args[2:], **kwargs)
    elif isinstance(dst, Float8Tensor) and isinstance(src, Float8Tensor):
        assert (
            dst._orig_dtype == src._orig_dtype
        ), "Expecting both Float8Tensors to be of the same dtype"
        assert (
            dst._scale == src._scale
        ), "Expecting both Float8Tensors to have thee same scale"
        assert (
            dst._linear_mm_config == src._linear_mm_config
        ), "Expecting both Float8Tensors to have thee same mm config"
        assert (
            dst._data.dtype == src._data.dtype
        ), "Expecting both Float8Tensors to be of the same dtypet"
        assert (
            dst._gemm_input_role == src._gemm_input_role
        ), "Expecting both Float8Tensors to have the same gemm_input_role"
        fp8_out = aten_op(dst._data, src._data, *args[2:], **kwargs)
        return Float8Tensor(
            fp8_out,
            dst._scale,
            dst._orig_dtype,
            dst._linear_mm_config,
            dst._gemm_input_role,
        )
    else:
        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")