-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathoneshot_nas_blocks.py
538 lines (458 loc) · 25.7 KB
/
oneshot_nas_blocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
from mxnet.gluon import nn
from mxnet.gluon.nn import HybridBlock
from mxnet import nd
import random
import numpy as np
# Names exported by ``from <this module> import *``.
__all__ = [
    'ShuffleNetBlock',
    'ShuffleNasBlock',
    'Activation',
    'SE',
    'NasBatchNorm',
    'NasHybridSequential',
]
class Activation(HybridBlock):
    """Activation function used in MobileNetV3, selected by name.

    Parameters
    ----------
    act_func : str
        One of ``"relu"``, ``"relu6"``, ``"hard_sigmoid"``, ``"swish"``,
        ``"hard_swish"`` or ``"leaky"`` (LeakyReLU with slope 0.375).
    **kwargs
        Forwarded to ``HybridBlock.__init__`` (e.g. ``prefix``, ``params``).

    Raises
    ------
    NotImplementedError
        If ``act_func`` is not one of the supported names.
    """
    def __init__(self, act_func, **kwargs):
        super(Activation, self).__init__(**kwargs)
        if act_func == "relu":
            self.act = nn.Activation('relu')
        elif act_func == "relu6":
            self.act = ReLU6()
        elif act_func == "hard_sigmoid":
            self.act = HardSigmoid()
        elif act_func == "swish":
            self.act = nn.Swish()
        elif act_func == "hard_swish":
            self.act = HardSwish()
        elif act_func == "leaky":
            self.act = nn.LeakyReLU(alpha=0.375)
        else:
            # Name the offending value so configuration typos are easy to spot.
            raise NotImplementedError(
                "Unsupported activation function: {!r}".format(act_func))

    def hybrid_forward(self, F, x):
        return self.act(x)
class ReLU6(HybridBlock):
    """ReLU capped at 6, i.e. ``min(max(x, 0), 6)``."""

    def __init__(self, **kwargs):
        super(ReLU6, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        # Clipping into [0, 6] is exactly ReLU6.
        return F.clip(x, a_min=0, a_max=6, name="relu6")
class HardSigmoid(HybridBlock):
    """Hard sigmoid activation: ``clip(x + 3, 0, 6) / 6``.

    A piecewise-linear sigmoid approximation as used in MobileNetV3. The
    forward pass clips directly; the previously constructed (and never
    called) ``ReLU6`` child has been removed.
    """
    def __init__(self, **kwargs):
        super(HardSigmoid, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        return F.clip(x + 3, 0, 6, name="hard_sigmoid") / 6.
class HardSwish(HybridBlock):
    """Hard swish activation: ``x * clip(x + 3, 0, 6) / 6``.

    The MobileNetV3 swish approximation. The forward pass computes the hard
    sigmoid inline; the previously constructed (and never called)
    ``HardSigmoid`` child has been removed.
    """
    def __init__(self, **kwargs):
        super(HardSwish, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        return x * (F.clip(x + 3, 0, 6, name="hard_swish") / 6.)
class SE(HybridBlock):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools the input, squeezes it with a 1x1 conv to ``num_mid``
    channels, expands back to ``num_in`` channels with another 1x1 conv, and
    multiplies the resulting per-channel gates into the input.

    Parameters
    ----------
    num_in : int
        Number of input channels (the output has the same number).
    ratio : int, default 4
        Squeeze ratio; the bottleneck width is ``num_in // ratio`` rounded up
        to a multiple of 8.
    act_func : tuple of str, default ("relu", "hard_sigmoid")
        ``Activation`` names applied after the squeeze and the excitation
        convolutions respectively.
    use_bn : bool, default False
        Stored as ``self.use_bn``; not used by this block.
    **kwargs
        Forwarded to ``HybridBlock.__init__``.
    """
    def __init__(self, num_in, ratio=4,
                 act_func=("relu", "hard_sigmoid"), use_bn=False, **kwargs):
        super(SE, self).__init__(**kwargs)
        self.use_bn = use_bn
        num_out = num_in
        # Round the bottleneck width up to a multiple of 8 (the original
        # author's rationale: better cache hit ratio). Uses the shared
        # module-level helper instead of a private duplicate.
        num_mid = make_divisible(num_out // ratio)
        with self.name_scope():
            self.channel_attention = nn.HybridSequential()
            self.channel_attention.add(nn.GlobalAvgPool2D(),
                                       nn.Conv2D(channels=num_mid, in_channels=num_in, kernel_size=1, use_bias=True,
                                                 prefix='conv_squeeze_'),
                                       Activation(act_func[0]),
                                       nn.Conv2D(channels=num_out, in_channels=num_mid, kernel_size=1, use_bias=True,
                                                 prefix='conv_excitation_'),
                                       Activation(act_func[1]))

    def hybrid_forward(self, F, x):
        out = self.channel_attention(x)
        # Per-channel gates broadcast over the spatial dimensions of x.
        return F.broadcast_mul(x, out)
class ShuffleChannels(HybridBlock):
    """
    ShuffleNet channel shuffle Block.

    Interleaves the channels of the 4-D input across ``groups`` (always 2)
    and splits the shuffled tensor along axis 1 into the first
    ``mid_channel`` channels and the remaining channels.

    For reshape 0, -1, -2, -3, -4 meaning:
    https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html?highlight=reshape#mxnet.ndarray.NDArray.reshape
    """
    def __init__(self, mid_channel, groups=2, **kwargs):
        # NOTE: extra **kwargs are accepted but not forwarded to HybridBlock.
        super(ShuffleChannels, self).__init__()
        # ShuffleNet v2 always shuffles between exactly two groups.
        assert groups == 2
        self.groups = groups
        self.mid_channel = mid_channel

    def hybrid_forward(self, F, x, *args, **kwargs):
        # (N, C, ...) -> (N, groups, C // groups, ...)
        grouped = F.reshape(x, shape=(0, -4, self.groups, -1, -2))
        # Swap the group and per-group axes, then merge them back into one
        # channel axis, so channels from the two groups alternate.
        shuffled = F.reshape(F.swapaxes(grouped, 1, 2), shape=(0, -3, -2))
        cut = self.mid_channel
        head = F.slice(shuffled, begin=(None, None, None, None), end=(None, cut, None, None))
        tail = F.slice(shuffled, begin=(None, cut, None, None), end=(None, None, None, None))
        return head, tail
class ShuffleChannelsConv(HybridBlock):
    """
    ShuffleNet channel shuffle Block implemented as a fixed 1x1 convolution.

    ``transpose_conv`` is a bias-free 1x1 convolution over ``2 * mid_channel``
    channels. After ``transpose_init`` it holds a permutation matrix that
    sends input channel ``k`` to output ``2k`` and input ``C/2 + k`` to output
    ``2k + 1`` (the same interleaving ``ShuffleChannels`` produces with
    reshapes). The result is split along axis 1 into the first
    ``mid_channel`` channels and the rest.

    For reshape 0, -1, -2, -3, -4 meaning:
    https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html?highlight=reshape#mxnet.ndarray.NDArray.reshape
    """
    def __init__(self, mid_channel, groups=2, **kwargs):
        # NOTE: extra **kwargs are accepted but not forwarded to HybridBlock.
        super(ShuffleChannelsConv, self).__init__()
        # For ShuffleNet v2, groups is always set 2
        assert groups == 2
        self.groups = groups
        self.mid_channel = int(mid_channel)
        self.channels = int(mid_channel * 2)
        # NOTE(review): created without entering this block's name_scope, so
        # its prefix comes from whatever scope is active at construction time;
        # left as-is to keep existing parameter names stable.
        self.transpose_conv = nn.Conv2D(self.channels, in_channels=self.channels, kernel_size=1, strides=1,
                                        padding=0, use_bias=False, prefix='transpose_conv_')

    def transpose_init(self):
        """Load the permutation kernel into ``transpose_conv`` and freeze it.

        Raises
        ------
        ValueError
            If ``transpose_conv`` has more than one parameter (the first one
            is already overwritten by the time this is detected).
        """
        for i, param in enumerate(self.transpose_conv.collect_params().values()):
            if i > 0:
                raise ValueError('Transpose conv should only have the weights parameter.')
            param.set_data(self.generate_transpose_conv_kernel())
            # The permutation is fixed; exclude it from gradient updates.
            param.grad_req = 'null'

    def generate_transpose_conv_kernel(self):
        """Return the ``(C, C, 1, 1)`` channel-permutation kernel as an NDArray.

        Row ``o`` has a single 1 at column ``idx[o]``; even outputs read the
        first half of the inputs, odd outputs read the second half.

        Raises
        ------
        ValueError
            If the channel count is odd.
        """
        c = self.channels
        if c % 2 != 0:
            raise ValueError('Channel number should be even.')
        half = c // 2
        idx = np.zeros(c)
        idx[np.arange(0, c, 2)] = np.arange(half)
        idx[np.arange(1, c, 2)] = np.arange(half, c, 1)
        weights = np.zeros((c, c))
        weights[np.arange(c), idx.astype(int)] = 1.0
        # (C, C) -> (C, C, 1, 1): 1x1 kernel layout for the convolution.
        return nd.expand_dims(nd.expand_dims(nd.array(weights), axis=2), axis=3)

    def hybrid_forward(self, F, x, *args, **kwargs):
        data = self.transpose_conv(x)
        data_project = F.slice(data, begin=(None, None, None, None), end=(None, self.mid_channel, None, None))
        data_x = F.slice(data, begin=(None, self.mid_channel, None, None), end=(None, None, None, None))
        return data_project, data_x
class ChannelSelector(HybridBlock):
    """
    Random channel # selection.

    Multiplies ``x`` channel-wise by the first ``channel_number`` columns of
    ``block_channel_mask``. The mask is sliced as 2-D and reshaped to
    ``(1, channel_number, 1, 1)``, so it must hold a single row of at least
    ``channel_number`` entries. Zeros in the mask switch channels off.
    """
    def __init__(self, channel_number):
        super(ChannelSelector, self).__init__()
        self.channel_number = channel_number

    def hybrid_forward(self, F, x, block_channel_mask, *args, **kwargs):
        width = self.channel_number
        leading = F.slice(block_channel_mask, begin=(None, None), end=(None, width))
        gate = F.reshape(leading, shape=(1, width, 1, 1))
        return F.broadcast_mul(x, gate)
class ShuffleNetBlock(HybridBlock):
    """ShuffleNet v2 style unit: a main branch concatenated with a project branch.

    With ``stride == 1`` the input is channel-shuffled and split in half. The
    first half passes through unchanged; the second half goes through the
    main branch. With ``stride == 2`` both the main branch and a depthwise
    "project" branch consume the full input. The two results are
    concatenated along axis 1 to give ``output_channel`` channels.

    Parameters
    ----------
    input_channel : int
        Channels of the block input.
    output_channel : int
        Channels of the concatenated output.
    mid_channel : int
        Intermediate width of the main branch.
    ksize : int
        Depthwise kernel size; one of 3, 5, 7.
    stride : int
        1 or 2 (see above).
    shuffle_method : class, default ShuffleChannels
        Constructed as ``shuffle_method(mid_channel=input_channel // 2, groups=2)``
        when ``stride == 1``; must return ``(pass_through_half, main_half)``.
    block_mode : str
        'ShuffleNetV2' (pw -> dw -> pw) or 'ShuffleXception'
        (dw -> pw repeated three times).
    fix_arch : bool
        True: the main branch is a plain ``nn.HybridSequential``.
        False: it is a ``NasBaseHybridSequential`` with a ``ChannelSelector``
        after each ``mid_channel``-wide pointwise conv, which needs a channel
        mask at forward time. NOTE(review): this class's own
        ``hybrid_forward`` does not pass a mask to the main branch, so
        fix_arch=False is presumably only meant for ``ShuffleNetCSBlock``.
    bn : callable
        Normalization layer factory, called as ``bn(in_channels=..., momentum=0.1)``.
    act_name : str
        Activation name passed to ``Activation``.
    use_se : bool
        Append an ``SE`` block to the end of the main branch.
    **kwargs
        Accepted but not forwarded to ``HybridBlock.__init__``.
    """
    def __init__(self, input_channel, output_channel, mid_channel, ksize, stride, shuffle_method=ShuffleChannels,
                 block_mode='ShuffleNetV2', fix_arch=True, bn=nn.BatchNorm, act_name='relu', use_se=False, **kwargs):
        super(ShuffleNetBlock, self).__init__()
        assert stride in [1, 2]
        assert ksize in [3, 5, 7]
        assert block_mode in ['ShuffleNetV2', 'ShuffleXception']
        self.stride = stride
        self.ksize = ksize
        # "same" padding for the odd kernel sizes allowed above
        self.padding = self.ksize // 2
        self.block_mode = block_mode
        self.input_channel = input_channel
        self.output_channel = output_channel
        # project_input_C == project_mid_C == project_output_C == main_input_channel
        self.project_channel = input_channel // 2 if stride == 1 else input_channel
        # stride 1, input will be split
        self.main_input_channel = input_channel // 2 if stride == 1 else input_channel
        self.main_mid_channel = mid_channel
        # The main branch produces whatever the project branch does not, so the
        # concatenated output has exactly output_channel channels.
        self.main_output_channel = output_channel - self.project_channel
        self.fix_arch = fix_arch
        # NOTE: Gluon auto-names the child layers below in creation order;
        # reordering them would change parameter names (and saved checkpoints).
        with self.name_scope():
            """
            Regular block: (We usually have the down-sample block first, then followed by repeated regular blocks)
            Input[64] -> split two halves -> main branch: [32] --> mid_channels (final_output_C[64] // 2 * scale[1.4])
            | |--> main_out_C[32] (final_out_C (64) - input_C[32]
            |
            |-----> project branch: [32], do nothing on this half
            Concat two copies: [64 - 32] + [32] --> [64] for final output channel
            =====================================================================
            In "Single path one shot nas" paper, Channel Search is searching for the main branch intermediate #channel.
            And the mid channel is controlled / selected by the channel scales (0.2 ~ 2.0), calculated from:
            mid channel = block final output # channel // 2 * scale
            Since scale ~ (0, 2), this is guaranteed: main mid channel < final output channel
            """
            if stride == 1:
                self.channel_shuffle_and_split = shuffle_method(mid_channel=input_channel // 2, groups=2)
            self.main_branch = nn.HybridSequential() if fix_arch else NasBaseHybridSequential()
            if block_mode == 'ShuffleNetV2':
                self.main_branch.add(
                    # pw
                    nn.Conv2D(self.main_mid_channel, in_channels=self.main_input_channel, kernel_size=1, strides=1,
                              padding=0, use_bias=False))
                if not fix_arch:
                    # Mask the mid channels according to the sampled channel choice.
                    self.main_branch.add(ChannelSelector(channel_number=self.main_mid_channel))
                self.main_branch.add(
                    bn(in_channels=self.main_mid_channel, momentum=0.1),
                    Activation(act_name),
                    # dw with linear output
                    nn.Conv2D(self.main_mid_channel, in_channels=self.main_mid_channel, kernel_size=self.ksize,
                              strides=self.stride, padding=self.padding, groups=self.main_mid_channel, use_bias=False),
                    bn(in_channels=self.main_mid_channel, momentum=0.1),
                    # pw
                    nn.Conv2D(self.main_output_channel, in_channels=self.main_mid_channel, kernel_size=1, strides=1,
                              padding=0, use_bias=False),
                    bn(in_channels=self.main_output_channel, momentum=0.1),
                    Activation(act_name)
                )
            elif block_mode == 'ShuffleXception':
                self.main_branch.add(
                    # dw with linear output (the only dw layer carrying the block stride)
                    nn.Conv2D(self.main_input_channel, in_channels=self.main_input_channel, kernel_size=self.ksize,
                              strides=self.stride, padding=self.padding, groups=self.main_input_channel, use_bias=False),
                    bn(in_channels=self.main_input_channel, momentum=0.1),
                    # pw
                    nn.Conv2D(self.main_mid_channel, in_channels=self.main_input_channel, kernel_size=1, strides=1,
                              padding=0, use_bias=False))
                if not fix_arch:
                    self.main_branch.add(ChannelSelector(channel_number=self.main_mid_channel))
                self.main_branch.add(
                    bn(in_channels=self.main_mid_channel, momentum=0.1),
                    Activation(act_name),
                    # dw with linear output
                    nn.Conv2D(self.main_mid_channel, in_channels=self.main_mid_channel, kernel_size=self.ksize,
                              strides=1, padding=self.padding, groups=self.main_mid_channel, use_bias=False),
                    bn(in_channels=self.main_mid_channel, momentum=0.1),
                    # pw
                    nn.Conv2D(self.main_mid_channel, in_channels=self.main_mid_channel, kernel_size=1, strides=1,
                              padding=0, use_bias=False))
                if not fix_arch:
                    self.main_branch.add(ChannelSelector(channel_number=self.main_mid_channel))
                self.main_branch.add(
                    bn(in_channels=self.main_mid_channel, momentum=0.1),
                    Activation(act_name),
                    # dw with linear output
                    nn.Conv2D(self.main_mid_channel, in_channels=self.main_mid_channel, kernel_size=self.ksize,
                              strides=1, padding=self.padding, groups=self.main_mid_channel, use_bias=False),
                    bn(in_channels=self.main_mid_channel, momentum=0.1),
                    # pw
                    nn.Conv2D(self.main_output_channel, in_channels=self.main_mid_channel, kernel_size=1, strides=1,
                              padding=0, use_bias=False),
                    bn(in_channels=self.main_output_channel, momentum=0.1),
                    Activation(act_name)
                )
            if use_se:
                self.main_branch.add(SE(self.main_output_channel))
            if self.stride == 2:
                """
                Down-sample block:
                Input[16] -> two copies -> main branch: [16] --> mid_channels (final_output_C[64] // 2 * scale[1.4])
                | |--> main_out_C[48] (final_out_C (64) - input_C[16])
                |
                |-----> project branch: [16] --> project_mid_C[16] --> project_out_C[16]
                Concat two copies: [64 - 16] + [16] --> [64] for final output channel
                """
                self.proj_branch = nn.HybridSequential()
                self.proj_branch.add(
                    # dw with linear output
                    nn.Conv2D(self.project_channel, in_channels=self.project_channel, kernel_size=self.ksize,
                              strides=stride, padding=self.padding, groups=self.project_channel, use_bias=False),
                    bn(in_channels=self.project_channel, momentum=0.1),
                    # pw
                    nn.Conv2D(self.project_channel, in_channels=self.project_channel, kernel_size=1, strides=1,
                              padding=0, use_bias=False),
                    bn(in_channels=self.project_channel, momentum=0.1),
                    Activation(act_name)
                )

    def hybrid_forward(self, F, old_x, *args, **kwargs):
        # Extra positional arguments (e.g. a channel mask supplied by a parent
        # container) are accepted and ignored here.
        if self.stride == 2:
            # Down-sample: both branches see the full input.
            x_project = old_x
            x = old_x
            return F.concat(self.proj_branch(x_project), self.main_branch(x), dim=1)
        elif self.stride == 1:
            # Regular: one shuffled half passes through, the other is transformed.
            x_project, x = self.channel_shuffle_and_split(old_x)
            return F.concat(x_project, self.main_branch(x), dim=1)
class ShuffleNetCSBlock(ShuffleNetBlock):
    """
    ShuffleNetBlock with Channel Selecting.

    Builds the same layers as ``ShuffleNetBlock`` but defaults to
    ``fix_arch=False``, so the main branch contains ``ChannelSelector``
    layers. The forward pass therefore takes a channel mask (``channel_choice``)
    and hands it to the main branch.
    """
    def __init__(self, input_channel, output_channel, mid_channel, ksize, stride,
                 block_mode='ShuffleNetV2', fix_arch=False, bn=nn.BatchNorm, act_name='relu', use_se=False, **kwargs):
        super(ShuffleNetCSBlock, self).__init__(
            input_channel, output_channel, mid_channel, ksize, stride,
            block_mode=block_mode, fix_arch=fix_arch, bn=bn,
            act_name=act_name, use_se=use_se, **kwargs)

    def hybrid_forward(self, F, old_x, channel_choice, *args, **kwargs):
        if self.stride == 2:
            # Down-sample: both branches see the full input (project branch first).
            projected = self.proj_branch(old_x)
            transformed = self.main_branch(old_x, channel_choice)
        elif self.stride == 1:
            # Regular: keep one shuffled half, transform the other.
            projected, half = self.channel_shuffle_and_split(old_x)
            transformed = self.main_branch(half, channel_choice)
        else:
            # Unreachable: stride is validated in ShuffleNetBlock.__init__.
            return None
        return F.concat(projected, transformed, dim=1)
class ShuffleNasBlock(HybridBlock):
    """Searchable choice block holding four candidate ShuffleNet units.

    Candidates, indexed by ``block_choice``: 0 = ShuffleNetV2 3x3,
    1 = ShuffleNetV2 5x5, 2 = ShuffleNetV2 7x7, 3 = ShuffleXception 3x3.
    Each one is a ``ShuffleNetCSBlock`` built with the widest main-branch
    width ``make_divisible(int(output_channel // 2 * max_channel_scale))``.

    If ``use_all_blocks`` is True, all four candidates run and their outputs
    are averaged. Otherwise only the candidate whose index equals
    ``block_choice`` runs; if none matches, ``x`` is returned unchanged.
    NOTE(review): ``block_choice`` is compared in Python, so this presumably
    requires imperative (non-hybridized) execution — confirm.
    """
    def __init__(self, input_channel, output_channel, stride, max_channel_scale=2.0,
                 use_all_blocks=False, bn=nn.BatchNorm, act_name='relu', use_se=False, **kwargs):
        super(ShuffleNasBlock, self).__init__()
        assert stride in [1, 2]
        self.use_all_blocks = use_all_blocks
        with self.name_scope():
            widest = make_divisible(int(output_channel // 2 * max_channel_scale))
            shared = dict(bn=bn, act_name=act_name, use_se=use_se)
            # Creation order determines auto-generated parameter names; keep it.
            self.block_sn_3x3 = ShuffleNetCSBlock(input_channel, output_channel, widest,
                                                  3, stride, 'ShuffleNetV2', **shared)
            self.block_sn_5x5 = ShuffleNetCSBlock(input_channel, output_channel, widest,
                                                  5, stride, 'ShuffleNetV2', **shared)
            self.block_sn_7x7 = ShuffleNetCSBlock(input_channel, output_channel, widest,
                                                  7, stride, 'ShuffleNetV2', **shared)
            self.block_sx_3x3 = ShuffleNetCSBlock(input_channel, output_channel, widest,
                                                  3, stride, 'ShuffleXception', **shared)

    def hybrid_forward(self, F, x, block_choice, block_channel_mask, *args, **kwargs):
        # Takes three inputs; each candidate only needs x and the channel mask.
        candidates = (self.block_sn_3x3, self.block_sn_5x5, self.block_sn_7x7, self.block_sx_3x3)
        if self.use_all_blocks:
            outs = [candidate(x, block_channel_mask) for candidate in candidates]
            return (outs[0] + outs[1] + outs[2] + outs[3]) / 4
        for choice_id, candidate in enumerate(candidates):
            if block_choice == choice_id:
                return candidate(x, block_channel_mask)
        return x
class NasBaseHybridSequential(nn.HybridSequential):
    """HybridSequential that also hands a channel mask to its ChannelSelector children.

    Every child is called in order on the running output. ``ChannelSelector``
    children additionally receive ``block_channel_mask``.
    """
    def __init__(self, prefix=None, params=None):
        super(NasBaseHybridSequential, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, block_channel_mask, *args, **kwargs):
        out = x
        for child in self._children.values():
            extra = (block_channel_mask,) if isinstance(child, ChannelSelector) else ()
            out = child(out, *extra)
        return out
class NasHybridSequential(nn.HybridSequential):
    """Sequential container that routes per-block architecture choices.

    Children are called in order:

    - ``ShuffleNasBlock`` children get the next entry of ``full_arch`` (as a
      length-1 slice) and the matching row of ``full_channel_mask``.
    - ``ShuffleNetBlock`` children (including subclasses) get the next row of
      ``full_channel_mask``.
    - Anything else is called on ``x`` alone.

    NOTE(review): the two kinds keep separate counters that both start at
    row 0, so the same sequence presumably contains only one of them.
    """
    def __init__(self, prefix=None, params=None):
        super(NasHybridSequential, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, full_arch, full_channel_mask):
        nas_row = 0
        base_row = 0
        for child in self._children.values():
            if isinstance(child, ShuffleNasBlock):
                choice = F.slice(full_arch, begin=nas_row, end=nas_row + 1)
                mask_row = F.slice(full_channel_mask, begin=(nas_row, None), end=(nas_row + 1, None))
                x = child(x, choice, mask_row)
                nas_row += 1
            elif isinstance(child, ShuffleNetBlock):
                mask_row = F.slice(full_channel_mask, begin=(base_row, None), end=(base_row + 1, None))
                x = child(x, mask_row)
                base_row += 1
            else:
                x = child(x)
        return x
class NasBatchNorm(HybridBlock):
    """Batch normalization that can re-estimate running statistics on the fly.

    With ``inference_update_stat=False`` this is a plain ``F.BatchNorm`` over
    the stored gamma/beta/running statistics.

    With ``inference_update_stat=True`` each forward pass does the following:

    - It computes the per-channel mean and (biased) variance of ``x`` over
      axes (0, 2, 3).
    - It normalizes ``x`` with those batch statistics.
    - It blends them into the stored running statistics as
      ``running = running * momentum + batch * (1 - momentum)``.

    This mode reads ``x.context`` and calls ``Parameter.set_data``, so it
    presumably only works imperatively (not hybridized).

    Parameters
    ----------
    axis, momentum, epsilon, center, scale, use_global_stats,
    beta_initializer, gamma_initializer, running_mean_initializer,
    running_variance_initializer, in_channels
        Same meaning as for ``mxnet.gluon.nn.BatchNorm``; ``in_channels=0``
        defers the parameter shape.
    inference_update_stat : bool, default False
        Enable the batch-statistics mode described above.
    **kwargs
        Forwarded to ``HybridBlock.__init__``.
    """
    def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
                 use_global_stats=False, beta_initializer='zeros', gamma_initializer='ones',
                 running_mean_initializer='zeros', running_variance_initializer='ones',
                 in_channels=0, inference_update_stat=False, **kwargs):
        super(NasBatchNorm, self).__init__(**kwargs)
        self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
                        'fix_gamma': not scale, 'use_global_stats': use_global_stats}
        self.inference_update_stat = inference_update_stat
        if in_channels != 0:
            self.in_channels = in_channels
        # Registered unconditionally (as in gluon's BatchNorm) so that
        # in_channels=0 means "deferred shape" rather than "no parameters".
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(in_channels,), init=gamma_initializer,
                                     allow_deferred_init=True,
                                     differentiable=scale)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(in_channels,), init=beta_initializer,
                                    allow_deferred_init=True,
                                    differentiable=center)
        self.running_mean = self.params.get('running_mean', grad_req='null',
                                            shape=(in_channels,),
                                            init=running_mean_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
        self.running_var = self.params.get('running_var', grad_req='null',
                                           shape=(in_channels,),
                                           init=running_variance_initializer,
                                           allow_deferred_init=True,
                                           differentiable=False)
        # Blend factors for the running-statistics update, kept as 1-element
        # arrays and moved to x's context at use time.
        self.momentum = nd.array([self._kwargs['momentum']])
        self.momentum_rest = nd.array([1 - self._kwargs['momentum']])

    def cast(self, dtype):
        # Keep normalization parameters in float32 even under float16 casting.
        if np.dtype(dtype).name == 'float16':
            dtype = 'float32'
        super(NasBatchNorm, self).cast(dtype)

    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
        if not self.inference_update_stat:
            return F.BatchNorm(x, gamma, beta, running_mean, running_var, name='fwd', **self._kwargs)

        # Per-channel batch statistics over axes (0, 2, 3), i.e. assuming a
        # 4-D input with channels on axis 1.
        mean = x.mean(axis=(0, 2, 3))
        mean_expanded = F.expand_dims(F.expand_dims(F.expand_dims(mean, axis=0), axis=2), axis=3)
        var = F.square(F.broadcast_minus(x, mean_expanded)).mean(axis=(0, 2, 3))

        # Update the running statistics from the copies Gluon passed in for
        # x's context (previously self.running_*.data(), which picks a single
        # context and mismatches devices when several contexts are used).
        # TODO: with several devices the last set_data call still wins.
        ctx = x.context
        momentum = self.momentum.as_in_context(ctx)
        momentum_rest = self.momentum_rest.as_in_context(ctx)
        new_running_mean = F.add(F.multiply(running_mean, momentum),
                                 F.multiply(mean, momentum_rest))
        new_running_var = F.add(F.multiply(running_var, momentum),
                                F.multiply(var, momentum_rest))
        self.running_mean.set_data(new_running_mean)
        self.running_var.set_data(new_running_var)
        # Normalize with the freshly computed batch statistics.
        return F.BatchNorm(x, gamma, beta, mean, var, name='fwd', **self._kwargs)

    def __repr__(self):
        s = '{name}({content}'
        in_channels = self.gamma.shape[0]
        s += ', in_channels={0}'.format(in_channels if in_channels else None)
        s += ')'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))
def make_divisible(x, divisible_by=8):
    """Round ``x`` up to the nearest multiple of ``divisible_by``, as an int.

    >>> make_divisible(10)
    16
    """
    multiples = np.ceil(x * 1. / divisible_by)
    return int(multiples * divisible_by)
def random_block_choices(stage_repeats=None, num_of_block_choices=4):
    """Draw a uniformly random candidate-block index for every searchable block.

    Parameters
    ----------
    stage_repeats : list of int, optional
        Blocks per stage; defaults to ``[4, 4, 8, 4]``.
    num_of_block_choices : int, default 4
        Indices are drawn from ``0 .. num_of_block_choices - 1``.

    Returns
    -------
    NDArray
        One index per block, ``sum(stage_repeats)`` entries in total.
    """
    if stage_repeats is None:
        stage_repeats = [4, 4, 8, 4]
    highest = num_of_block_choices - 1
    choices = [random.randint(0, highest) for _ in range(sum(stage_repeats))]
    return nd.array(choices)
def random_channel_mask(stage_repeats=None, stage_out_channels=None, candidate_scales=None,
                        select_all_channels=False):
    """
    Build a 0/1 channel mask with one row per searchable block.

    For a block in a stage with ``out`` output channels, the row starts with
    ``k`` ones followed by zeros, where ``k`` is drawn uniformly from the
    integers in ``[int(out // 2 * min(scales)), int(out // 2 * max(scales))]``.
    With ``select_all_channels`` the row is all ones and no random numbers
    are drawn. All rows have the same width so they stack into one 2-D
    array. The width is that of the widest stage; it was previously taken
    from the last stage and last scale, which raised IndexError when either
    list was not in ascending order.

    NOTE(review): k is any integer in the range, not restricted to the listed
    candidate scales — confirm this is the intended sampling.

    candidate_scales = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

    Parameters
    ----------
    stage_repeats : list of int, optional
        Blocks per stage; defaults to ``[4, 4, 8, 4]``.
    stage_out_channels : list of int, optional
        Output channels per stage; defaults to ``[64, 160, 320, 640]``.
    candidate_scales : list of float, optional
        Channel scales; defaults to the list above.
    select_all_channels : bool, default False
        Enable every channel instead of sampling.

    Returns
    -------
    NDArray
        Shape ``(sum(stage_repeats), max_width)``.
    """
    if stage_repeats is None:
        stage_repeats = [4, 4, 8, 4]
    if stage_out_channels is None:
        stage_out_channels = [64, 160, 320, 640]
    if candidate_scales is None:
        candidate_scales = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
    assert len(stage_repeats) == len(stage_out_channels)
    max_scale = max(candidate_scales)
    min_scale = min(candidate_scales)
    global_max_length = max((int(c // 2 * max_scale) for c in stage_out_channels), default=0)
    channel_mask = []
    for out_channels, repeats in zip(stage_out_channels, stage_repeats):
        local_max_length = int(out_channels // 2 * max_scale)
        local_min_length = int(out_channels // 2 * min_scale)
        for _ in range(repeats):
            if select_all_channels:
                selected = global_max_length
            else:
                selected = random.randint(local_min_length, local_max_length)
            channel_mask.append([1] * selected + [0] * (global_max_length - selected))
    return nd.array(channel_mask)
def main():
    """Point the user at the standalone test-case script for these blocks."""
    hint = ("If you want to verify these modules with test cases, please use this tool: "
            "Single-Path-One-Shot-Nas-MXNet/utils/test_cases.py")
    print(hint)


if __name__ == '__main__':
    main()