forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMKLDNNConversions.cpp
431 lines (395 loc) · 16.1 KB
/
MKLDNNConversions.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#endif
namespace at { namespace native {
#if AT_MKLDNN_ENABLED()
// Converts an MKL-DNN (opaque) tensor into a regular strided CPU tensor.
//
// `dtype` selects the element type of the result and defaults to the input's
// own element type. Input and output may each be float, bfloat16, uint8 or
// int8. uint8/int8 inputs must keep their element type. `masked_grad` is
// accepted for signature compatibility but is not read here. The returned
// tensor is made contiguous before being handed back.
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
  const ScalarType src_type = mkldnn_tensor.scalar_type();
  TORCH_CHECK(src_type == ScalarType::Float ||
              src_type == ScalarType::BFloat16 ||
              src_type == ScalarType::Byte ||
              src_type == ScalarType::Char,
              "mkldnn_to_dense expects float, bfloat16, uint8, int8 tensor input");

  ideep::tensor& src_itensor = itensor_from_mkldnn(mkldnn_tensor);
  const auto src_dims = src_itensor.get_dims();

  const ScalarType dst_type = dtype.has_value() ? dtype.value() : src_type;
  TORCH_CHECK(dst_type == ScalarType::Float ||
              dst_type == ScalarType::BFloat16 ||
              dst_type == ScalarType::Byte ||
              dst_type == ScalarType::Char,
              "mkldnn tensor only can be converted to be a float, bfloat16, uint8, int8 cpu tensor");

  const bool src_is_int8 = src_type == ScalarType::Byte || src_type == ScalarType::Char;
  if (src_is_int8) {
    // uint8/int8 payloads are never converted to a different element type.
    TORCH_CHECK(src_type == dst_type,
                "For int8, uint8 mkldnn_tensor input, we should not change the data type.");
  }

  // The size list is built as int64_t, as at::empty expects.
  Tensor cpu_tensor = at::empty(
      std::vector<int64_t>(src_dims.begin(), src_dims.end()),
      mkldnn_tensor.options().layout(c10::kStrided).dtype(dst_type));
  if (src_itensor.is_empty()) {
    return cpu_tensor;
  }

  // Write the MKL-DNN data into cpu_tensor's buffer in public format, using
  // the ideep data type that corresponds to dst_type.
  auto pub_tensor = [&]() -> ideep::tensor {
    if (dst_type == ScalarType::Float) {
      return src_itensor.to_public(cpu_tensor.data_ptr<float>(),
                                   ideep::tensor::data_type::f32);
    }
    if (dst_type == ScalarType::BFloat16) {
      return src_itensor.to_public(cpu_tensor.data_ptr<BFloat16>(),
                                   ideep::tensor::data_type::bf16);
    }
    if (dst_type == ScalarType::Byte) {
      return src_itensor.to_public(cpu_tensor.data_ptr<uint8_t>(),
                                   ideep::tensor::data_type::u8);
    }
    return src_itensor.to_public(cpu_tensor.data_ptr<int8_t>(),
                                 ideep::tensor::data_type::s8);
  }();

  // Adopt the strides reported by the public tensor, then return a contiguous
  // result.
  cpu_tensor.as_strided_(src_dims, pub_tensor.get_strides());
  return cpu_tensor.contiguous();
}
// Converts a strided CPU tensor into an MKL-DNN (opaque) tensor.
//
// The input must be a CPU, strided float/bfloat16/uint8/int8 tensor with at
// most 5 dimensions. `dtype` selects the element type of the MKL-DNN result
// and defaults to the input's type. uint8/int8 inputs must keep their type.
// The data is fed to ideep from a contiguous copy/view of the input.
Tensor dense_to_mkldnn(const Tensor& cpu_tensor, c10::optional<ScalarType> dtype) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(),
              "dense_to_mkldnn expects CPU tensor input");
  TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
              "dense_to_mkldnn expects strided tensor input");
  const ScalarType src_type = cpu_tensor.scalar_type();
  TORCH_CHECK(src_type == ScalarType::Float ||
              src_type == ScalarType::BFloat16 ||
              src_type == ScalarType::Byte ||
              src_type == ScalarType::Char,
              "dense_to_mkldnn expects float, bfloat16, uint8, int8 tensor input");
  TORCH_CHECK(cpu_tensor.dim() <= 5,
              "Can't convert cpu tensor with the number of dimensions > 5");

  const ScalarType dst_type = dtype.has_value() ? dtype.value() : src_type;
  if (src_type == ScalarType::Byte || src_type == ScalarType::Char) {
    // uint8/int8 inputs are never converted to a different element type.
    TORCH_CHECK(src_type == dst_type,
                "For int8, uint8 cpu_tensor input, we should not change the data type.");
  }
  TORCH_CHECK(dst_type == ScalarType::Float ||
              dst_type == ScalarType::BFloat16 ||
              dst_type == ScalarType::Byte ||
              dst_type == ScalarType::Char,
              "cpu tensor only can be converted to be a float, bfloat16, uint8, int8 mkldnn tensor");

  // Never hand ideep non-contiguous (or channels-last) memory directly.
  Tensor src = cpu_tensor.contiguous();
  const auto src_options = src.options();
  Tensor mkldnn_tensor = empty_mkldnn(src.sizes(), dst_type,
                                      src_options.layout_opt(), src_options.device_opt(),
                                      src_options.pinned_memory_opt());
  ideep::tensor& dst_itensor = itensor_from_mkldnn(mkldnn_tensor);
  const auto dst_dims = dst_itensor.get_dims();

  // The ideep data type passed here describes the *source* buffer.
  switch (src_type) {
    case ScalarType::Float:
      dst_itensor.feed_from(dst_dims, ideep::tensor::data_type::f32,
                            src.data_ptr<float>());
      break;
    case ScalarType::BFloat16:
      dst_itensor.feed_from(dst_dims, ideep::tensor::data_type::bf16,
                            src.data_ptr<BFloat16>());
      break;
    case ScalarType::Byte:
      dst_itensor.feed_from(dst_dims, ideep::tensor::data_type::u8,
                            src.data_ptr<uint8_t>());
      break;
    default:
      TORCH_CHECK(src_type == ScalarType::Char,
                  "Expect int8 input of cpu_tensor");
      dst_itensor.feed_from(dst_dims, ideep::tensor::data_type::s8,
                            src.data_ptr<int8_t>());
      break;
  }
  return mkldnn_tensor;
}
// Pre-packs a conv2d weight into the (possibly non-public) layout that ideep's
// convolution kernel expects.
//
// dense_to_mkldnn only produces MKL-DNN tensors in a public (plain) format.
// ideep's conv kernel implicitly reorders any weight that is not already in
// its optimized format. When this was originally written, that on-the-fly
// reorder was observed to cost roughly 20% in performance. This function lets
// callers do the reorder once, up front.
//
// If `input_size` is given, the expected layout is queried for that source
// shape and channels-last is always assumed. Otherwise the source dims are
// left empty and a contiguous layout is used. An MKL-DNN `self` is used as-is;
// a dense `self` is first made contiguous in the chosen memory format.
Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  if (self.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
                "mkldnn_reorder_conv2d_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
  }

  constexpr int64_t kSpatialDims = 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", kSpatialDims);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", kSpatialDims);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", kSpatialDims);

  // A known input size always implies a channels-last source.
  const bool is_channels_last = input_size.has_value();
  ideep::dims src_dims;
  if (is_channels_last) {
    src_dims = input_size.value().vec();
  }
  const auto memory_format =
      is_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  Tensor weight = self.is_mkldnn() ? self : self.contiguous(memory_format);
  ideep::tensor w = itensor_from_tensor(weight);

  // Legacy jitted mkldnn conv2d modules may hold a 5-D weight
  // [g, o/g, i, h, w] (groups > 1) instead of [o, i, h, w]. Ideally this
  // would be undone during deserialization. For backward compatibility, the
  // first two dims are folded back together here.
  if (w.ndims() == 5) {
    const auto legacy_dims = w.get_dims();
    w.reshape({legacy_dims[0] * legacy_dims[1], legacy_dims[2], legacy_dims[3], legacy_dims[4]});
  }

  auto packed_desc = ideep::convolution_forward::expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_expanded,
      dilation_expanded,
      groups,
      ideep::algorithm::convolution_direct,
      ideep::prop_kind::forward,
      w.get_data_type(),
      src_dims,
      ideep::attr_t(),
      is_channels_last);

  ideep::tensor packed;
  packed.init(packed_desc);
  packed.feed_from(w);
  return new_with_itensor_mkldnn(std::move(packed), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}
// Pre-packs a conv3d weight into the layout ideep's convolution kernel
// expects, so the kernel does not have to reorder it on every call.
//
// `self` is read through itensor_from_mkldnn, so it is expected to already be
// an MKL-DNN tensor. `padding`, `stride` and `dilation` are expanded to three
// spatial dimensions. The resulting MKL-DNN tensor keeps self's dtype and
// device.
Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  if (self.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
                "mkldnn_reorder_conv3d_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
  }

  constexpr int64_t kSpatialDims = 3;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", kSpatialDims);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", kSpatialDims);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", kSpatialDims);

  ideep::tensor w = itensor_from_mkldnn(self);
  auto packed_desc = ideep::convolution_forward::expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_expanded,
      dilation_expanded,
      groups,
      ideep::algorithm::convolution_direct);

  ideep::tensor packed;
  packed.init(packed_desc);
  packed.feed_from(w);
  return new_with_itensor_mkldnn(std::move(packed), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}
// Pre-packs a 2-D linear weight of shape (out_features, in_features) into the
// layout ideep's inner-product kernel expects.
//
// If `batch_size_opt` is given, the expected layout is queried for a source of
// shape (batch_size, in_features). Otherwise the source dims are left empty.
// The weight's data type is used for both the weight and source data types.
// The result is an MKL-DNN tensor with self's dtype and device.
static Tensor mkldnn_reorder_linear_weight(
    const Tensor& self,
    c10::optional<int64_t> batch_size_opt) {
  if (self.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_reorder_linear_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
  }
  // The packed descriptor below is built from exactly two dims. Reject other
  // ranks up front. Without this check, a 1-D weight fails with an opaque
  // index error and a >2-D weight is fed into a mismatched 2-D descriptor.
  TORCH_CHECK(self.dim() == 2,
      "mkldnn_reorder_linear_weight: expected a 2-D weight, but got a ",
      self.dim(), "-D tensor");
  const auto out_features = self.size(0);
  const auto in_features = self.size(1);
  auto self_ = self.contiguous();
  auto w = itensor_from_tensor(self_);
  const auto dtype = w.get_data_type();
  ideep::dims input_size;
  if (batch_size_opt.has_value()) {
    input_size = {batch_size_opt.value(), in_features};
  }
  auto packed_desc = ideep::inner_product_forward::expected_weights_desc(
      {out_features, in_features},
      input_size,
      /* weight dtype */ dtype,
      /* src dtype */ dtype);
  ideep::tensor result;
  result.init(packed_desc);
  result.feed_from(w);
  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
}
// Returns the ideep weight descriptor expected by a transposed convolution
// with the given geometry.
//
// ideep selects the channels-last variant through a template argument, so the
// runtime `channels_last` flag is mapped onto the two instantiations here.
// NOTE(review): `x_dtype` is accepted but not forwarded to ideep by this
// helper. Confirm whether that is intentional.
static ideep::tensor::desc get_conv_transpose_expected_weights_desc(
    const ideep::tensor::dims& weights_dims,
    ideep::tensor::data_type w_dtype,
    const ideep::tensor::dims& strides,
    const ideep::tensor::dims& padding_l,
    const ideep::tensor::dims& padding_r,
    const ideep::tensor::dims& dilates,
    int groups,
    bool channels_last,
    ideep::algorithm aalgorithm,
    ideep::data_type x_dtype,
    const ideep::dims& src_dims) {
  if (!channels_last) {
    return ideep::convolution_transpose_forward::expected_weights_desc<false>(
        weights_dims, w_dtype, strides, padding_l, padding_r, dilates, groups,
        aalgorithm, ideep::prop_kind::forward, src_dims);
  }
  return ideep::convolution_transpose_forward::expected_weights_desc<true>(
      weights_dims, w_dtype, strides, padding_l, padding_r, dilates, groups,
      aalgorithm, ideep::prop_kind::forward, src_dims);
}
// Pre-packs a 2-D transposed-convolution weight into the layout ideep's
// deconvolution kernel expects.
//
// If `input_size` is given, the expected layout is queried for that source
// shape and channels-last is always assumed. Otherwise the source dims are
// left empty and a contiguous layout is used. Autograd dispatch keys are
// excluded for the duration of this call. The result is an MKL-DNN tensor
// with self's dtype and device.
static Tensor mkldnn_reorder_conv_transpose2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  if (self.scalar_type() == ScalarType::BFloat16) {
    // Name this function (not the conv2d variant) so failures point at the
    // op that actually rejected the input.
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_reorder_conv_transpose2d_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
  }
  const auto padding_expanded = expand_param_if_needed(padding, "padding", 2);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2);
  const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", 2);

  ideep::dims src_dims = ideep::dims();
  bool is_channels_last = false;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (input_size.has_value()) {
    src_dims = input_size.value().vec();
    // if has input size, we always use channels last.
    is_channels_last = true;
    memory_format = at::MemoryFormat::ChannelsLast;
  }

  auto self_ = self.contiguous(memory_format);
  ideep::tensor w = itensor_from_tensor(self_);

  // The right-hand padding is derived from both padding and output_padding
  // via padding_r.
  auto expected_desc = get_conv_transpose_expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_r(padding_expanded, output_padding_expanded),
      dilation_expanded,
      groups,
      is_channels_last,
      ideep::algorithm::deconvolution_direct,
      w.get_data_type(),
      src_dims);
  // NOTE(review): the descriptor's axes are swapped to line up with `w` after
  // the transpose_(0, 1) below. With groups > 1 the swap is on axes (1, 2),
  // presumably because axis 0 is the group axis. Confirm against ideep's
  // grouped-weight layout.
  if (groups > 1) {
    expected_desc = expected_desc.transpose(1, 2);
  } else {
    expected_desc = expected_desc.transpose(0, 1);
  }

  ideep::tensor result;
  result.init(expected_desc);
  w.transpose_(0, 1);
  result.feed_from(w, /*is_deconv_weights*/true);
  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}
// CPU kernels for the weight pre-packing ops in the `mkldnn` operator
// namespace. Registrations are independent and listed alphabetically by
// op name.
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
      TORCH_FN(mkldnn_reorder_conv_transpose2d_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_weight"),
      TORCH_FN(mkldnn_reorder_conv2d_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_linear_weight"),
      TORCH_FN(mkldnn_reorder_linear_weight));
}
#else
// Fallback definitions compiled when AT_MKLDNN_ENABLED() is false. Each keeps
// the same signature as its MKL-DNN-enabled counterpart, but unconditionally
// raises an error through TORCH_CHECK(false, ...) when called.

// Stub: MKL-DNN -> dense conversion is unavailable in this build.
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
TORCH_CHECK(false, "MKL-DNN build is disabled");
}
// Stub: dense -> MKL-DNN conversion is unavailable in this build.
Tensor dense_to_mkldnn(const Tensor& cpu_tensor, c10::optional<ScalarType> dtype) {
TORCH_CHECK(false, "MKL-DNN build is disabled");
}
// Stub: conv2d weight pre-packing is unavailable in this build.
Tensor mkldnn_reorder_conv2d_weight(
const Tensor& self,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
c10::OptionalArrayRef<int64_t> input_size) {
TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
}
// Stub: conv3d weight pre-packing is unavailable in this build.
Tensor mkldnn_reorder_conv3d_weight(
const Tensor& self,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}
#endif // AT_MKLDNN_ENABLED()
#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()
#include <mkl.h>
// Pre-packs a float linear weight of shape (N, K) = (size(0), size(1)) with
// MKL's cblas_sgemm_pack, as the B operand of a GEMM whose A operand has
// `batch_size` rows.
//
// The packed buffer is returned inside an MKL-DNN tensor of shape
// {pack_size, 1}. pack_size is the size reported by
// cblas_sgemm_pack_get_size divided by sizeof(float), plus one element.
// Autograd dispatch keys are excluded for the packing work.
static Tensor mkl_reorder_linear_weight(
    const Tensor& weight,
    const int64_t batch_size) {
  TORCH_CHECK(
      weight.scalar_type() == ScalarType::Float,
      "reorder_linear_weight: weight's dtype should be float");
  // N and K are read from the first two dims and MKL packs exactly N * K
  // elements. Any other rank would be misread (>2-D) or fail opaquely (<2-D).
  TORCH_CHECK(
      weight.dim() == 2,
      "reorder_linear_weight: weight should be a 2-D tensor, but got a ",
      weight.dim(),
      "-D tensor");
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  const auto M = batch_size;
  const auto N = weight.size(0);
  const auto K = weight.size(1);
  // Convert the reported size to float elements; the extra element also
  // covers a size that is not a multiple of sizeof(float).
  const int64_t pack_size = static_cast<int64_t>(
      cblas_sgemm_pack_get_size(CblasBMatrix, M, N, K) / sizeof(float) + 1);
  auto packed_weight = empty_mkldnn(
      {pack_size, 1},
      weight.scalar_type(),
      weight.options().layout_opt(),
      weight.options().device_opt(),
      weight.options().pinned_memory_opt());
  ideep::tensor& mkl_weight = itensor_from_mkldnn(packed_weight);
  auto weight_ = weight.contiguous();
  const ideep::tensor orig_w = itensor_view_from_dense(weight_);
  // The weight is stored row-major as (N, K). It is packed as B = weight^T
  // (K x N) via CblasTrans with leading dimension K.
  cblas_sgemm_pack(
      CblasRowMajor,
      CblasBMatrix,
      CblasTrans,
      M,
      N,
      K,
      1.0f,
      static_cast<const float*>(orig_w.get_data_handle()),
      K,
      static_cast<float*>(mkl_weight.get_data_handle()));
  return packed_weight;
}
// Registers mkl_reorder_linear_weight as the CPU kernel of
// `mkl::_mkl_reorder_linear_weight`. Only compiled when both MKL and MKL-DNN
// are enabled.
TORCH_LIBRARY_IMPL(mkl, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkl::_mkl_reorder_linear_weight"),
TORCH_FN(mkl_reorder_linear_weight));
}
#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED
}}