#AUTOGENERATED! DO NOT EDIT! File to edit: dev/03_torchcore.ipynb (unless otherwise specified).
__all__ = ['progress_bar', 'master_bar', 'tensor', 'set_seed', 'unsqueeze', 'unsqueeze_', 'apply', 'maybe_gather',
'to_detach', 'to_half', 'to_float', 'default_device', 'to_device', 'to_cpu', 'to_np', 'to_concat',
'TensorBase', 'TensorCategory', 'TensorMultiCategory', 'TensorImageBase', 'TensorImage', 'TensorImageBW',
'TensorMask', 'concat', 'Chunks', 'one_param', 'item_find', 'find_device', 'find_bs', 'Module', 'get_model',
'one_hot', 'one_hot_decode', 'params', 'trainable_params', 'bn_types', 'bn_bias_params', 'batch_to_samples',
'logit', 'num_distrib', 'rank_distrib', 'distrib_barrier', 'make_cross_image', 'show_image_batch',
'requires_grad', 'init_default', 'cond_init', 'apply_leaf', 'apply_init', 'set_num_threads',
'ProcessPoolExecutor', 'parallel', 'run_procs', 'parallel_gen', 'script_use_ctx', 'script_save_ctx',
'script_fwd', 'script_bwd', 'grad_module', 'flatten_check']
#Cell
from .test import *
from .core.all import *
from .torch_imports import *
from fastprogress import progress_bar,master_bar
#Cell
if torch.cuda.is_available():
if torch.cuda.current_device()==0:
torch.cuda.set_device(int(os.environ.get('DEFAULT_GPU') or 0))
torch.backends.cudnn.benchmark = True
#Cell
@patch
def __array_eq__(self:Tensor,b):
return torch.equal(self,b) if self.dim() else self==b
#Cell
def _array2tensor(x):
if x.dtype==np.uint16: x = x.astype(np.float32)
return torch.from_numpy(x)
#Cell
def tensor(x, *rest, **kwargs):
"Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly."
if len(rest): x = (x,)+rest
# There was a PyTorch bug in DataLoader when using num_workers>0; not yet confirmed as fixed:
# if isinstance(x, (tuple,list)) and len(x)==0: return tensor(0)
res = (x if isinstance(x, Tensor)
else torch.tensor(x, **kwargs) if isinstance(x, (tuple,list))
else _array2tensor(x) if isinstance(x, ndarray)
else as_tensor(x.values, **kwargs) if isinstance(x, (pd.Series, pd.DataFrame))
else as_tensor(x, **kwargs) if hasattr(x, '__array__') or is_iter(x)
else _array2tensor(array(x), **kwargs))
if res.dtype is torch.float64: return res.float()
return res
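# Illustrative usage (editor's sketch, not part of the exported notebook):
#   tensor(1,2,3)                 # several scalars at once -> tensor([1, 2, 3])
#   tensor([[1,2],[3,4]])         # nested lists work like torch.tensor
#   tensor(np.array([1.,2.]))     # float64 arrays are downcast to float32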
#Cell
def set_seed(s):
"Set random seed for `random`, `torch`, and `numpy` (where available)"
try: torch.manual_seed(s)
except NameError: pass
try: np.random.seed(s%(2**32-1))
except NameError: pass
random.seed(s)
#Cell
def unsqueeze(x, dim=-1, n=1):
"Same as `torch.unsqueeze` but can add `n` dims"
for _ in range(n): x = x.unsqueeze(dim)
return x
#Cell
def unsqueeze_(x, dim=-1, n=1):
"Same as `torch.unsqueeze_` but can add `n` dims"
for _ in range(n): x.unsqueeze_(dim)
return x
#Cell
def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2(*args, **kwargs))
def _fa_rebuild_qtensor(cls, *args, **kwargs): return cls(torch._utils._rebuild_qtensor (*args, **kwargs))
#Cell
def apply(func, x, *args, **kwargs):
"Apply `func` recursively to `x`, passing on args"
if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])
if isinstance(x,dict): return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}
res = func(x, *args, **kwargs)
return res if x is None else retain_type(res, x)
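# Illustrative usage (editor's sketch, not part of the exported notebook): `apply` preserves nested structure,
#   apply(lambda t: t*2, [torch.ones(2), {'a': torch.zeros(3)}])
#   # -> [tensor([2., 2.]), {'a': tensor([0., 0., 0.])}]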
#Cell
def maybe_gather(x, axis=0):
"Gather copies of `x` on `axis` is training is distributed"
if num_distrib()<=1: return x
ndim = x.ndim
res = [x.new_zeros(*x.shape if ndim > 0 else (1,)) for _ in range(num_distrib())]
torch.distributed.all_gather(res, x if ndim > 0 else x[None])
return torch.cat(res, dim=axis) if ndim > 0 else torch.cat(res, dim=axis).mean()
#Cell
def to_detach(b, cpu=True, gather=True):
"Recursively detach lists of tensors in `b `; put them on the CPU if `cpu=True`."
def _inner(x, cpu=True, gather=True):
if not isinstance(x,Tensor): return x
x = x.detach()
if gather: x = maybe_gather(x)
return x.cpu() if cpu else x
return apply(_inner, b, cpu=cpu, gather=gather)
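# Illustrative usage (editor's sketch, not part of the exported notebook):
#   x = torch.rand(3, requires_grad=True)
#   to_detach((x, [x*2]))         # same nested structure, every tensor detached (and on the CPU)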
#Cell
def to_half(b):
"Recursively map lists of tensors in `b ` to FP16."
return apply(lambda x: x.half() if torch.is_floating_point(x) else x, b)
#Cell
def to_float(b):
"Recursively map lists of int tensors in `b ` to float."
return apply(lambda x: x.float() if torch.is_floating_point(x) else x, b)
#Cell
# defaults.use_cuda: None - use CUDA if available; True - error if not available; False - use CPU
defaults.use_cuda = None
#Cell
def default_device(use_cuda=-1):
"Return or set default device; `use_cuda`: None - CUDA if available; True - error if not availabe; False - CPU"
if use_cuda != -1: defaults.use_cuda=use_cuda
use = defaults.use_cuda or (torch.cuda.is_available() and defaults.use_cuda is None)
assert torch.cuda.is_available() or not use
return torch.device(torch.cuda.current_device()) if use else torch.device('cpu')
#Cell
def to_device(b, device=None):
"Recursively put `b` on `device`."
if device is None: device=default_device()
def _inner(o): return o.to(device, non_blocking=True) if isinstance(o,Tensor) else o
return apply(_inner, b)
#Cell
def to_cpu(b):
"Recursively map lists of tensors in `b ` to the cpu."
return to_device(b,'cpu')
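# Illustrative usage (editor's sketch, not part of the exported notebook):
#   default_device(False)         # force the CPU for subsequent calls
#   b = (torch.zeros(2), [torch.ones(3)])
#   to_device(b)                  # every tensor moved to default_device()
#   to_cpu(to_device(b))          # and back to the CPU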
#Cell
def to_np(x):
"Convert a tensor to a numpy array."
return apply(lambda o: o.data.cpu().numpy(), x)
#Cell
def to_concat(xs, dim=0):
"Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
if isinstance(xs[0],dict): return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs.keys()}
#We may receive `xs` that are not concatenable (e.g. the inputs of a text classifier), in which case we return a big list
try: return retain_type(torch.cat(xs, dim=dim), xs[0])
except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0]) for i in range_of(o_)) for o_ in xs], L())
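# Illustrative usage (editor's sketch, not part of the exported notebook): concatenate matching elements of several batch outputs,
#   to_concat([(torch.zeros(2,3), torch.ones(2)), (torch.zeros(4,3), torch.ones(4))])
#   # -> a tuple of a (6,3) tensor and a (6,) tensor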
#Cell
class TensorBase(Tensor):
def __new__(cls, x, **kwargs):
res = torch.Tensor._make_subclass(cls, tensor(x))
res._meta = kwargs
return res
def __reduce_ex__(self,proto):
torch.utils.hooks.warn_if_has_hooks(self)
args = (type(self), self.storage(), self.storage_offset(), tuple(self.size()), self.stride())
if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point())
f = _fa_rebuild_qtensor if self.is_quantized else _fa_rebuild_tensor
return (f, args + (self.requires_grad, OrderedDict()))
def gi(self, i):
res = self[i]
return type(self)(res) if isinstance(res,Tensor) else res
#Cell
def _patch_tb():
if getattr(TensorBase,'_patched',False): return
TensorBase._patched = True
def get_f(fn):
def _f(self, *args, **kwargs):
cls = self.__class__
res = getattr(super(TensorBase, self), fn)(*args, **kwargs)
return cls(res) if isinstance(res,Tensor) else res
return _f
t = tensor([1])
skips = '__getitem__ __class__ __deepcopy__ __delattr__ __dir__ __doc__ __getattribute__ __hash__ __init__ \
__init_subclass__ __new__ __reduce__ __reduce_ex__ __module__ __setstate__'.split()
for fn in dir(t):
if fn in skips: continue
f = getattr(t, fn)
if isinstance(f, (MethodWrapperType, BuiltinFunctionType, BuiltinMethodType, MethodType, FunctionType)):
setattr(TensorBase, fn, get_f(fn))
_patch_tb()
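# Illustrative usage (editor's sketch, not part of the exported notebook): thanks to `_patch_tb`, standard tensor ops keep the subclass,
#   t = TensorBase(torch.rand(3))
#   type(t + 1), type(t.mean())   # -> (TensorBase, TensorBase)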
#Cell
class TensorCategory(TensorBase): pass
#Cell
class TensorMultiCategory(TensorCategory): pass
#Cell
class TensorImageBase(TensorBase):
_show_args = ArrayImageBase._show_args
def show(self, ctx=None, **kwargs):
return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})
#Cell
class TensorImage(TensorImageBase): pass
#Cell
class TensorImageBW(TensorImage): _show_args = ArrayImageBW._show_args
#Cell
class TensorMask(TensorImageBase): _show_args = ArrayMask._show_args
#Cell
@patch
def tensored(self:L):
"`mapped(tensor)`"
return self.map(tensor)
@patch
def stack(self:L, dim=0):
"Same as `torch.stack`"
return torch.stack(list(self.tensored()), dim=dim)
@patch
def cat (self:L, dim=0):
"Same as `torch.cat`"
return torch.cat (list(self.tensored()), dim=dim)
#Cell
def concat(*ls):
"Concatenate tensors, arrays, lists, or tuples"
if not len(ls): return []
it = ls[0]
if isinstance(it,torch.Tensor): res = torch.cat(ls)
elif isinstance(it,ndarray): res = np.concatenate(ls)
else:
res = itertools.chain.from_iterable(map(L,ls))
if isinstance(it,(tuple,list)): res = type(it)(res)
else: res = L(res)
return retain_type(res, it)
#Cell
class Chunks:
"Slice and int indexing into a list of lists"
def __init__(self, chunks, lens=None):
self.chunks = chunks
self.lens = L(map(len,self.chunks) if lens is None else lens)
self.cumlens = np.cumsum(0+self.lens)
self.totlen = self.cumlens[-1]
def __getitem__(self,i):
if isinstance(i,slice): return retain_type(self.getslice(i), old=self.chunks[0])
di,idx = self.doc_idx(i)
return retain_type(self.chunks[di][idx], old=self.chunks[0])
def getslice(self, i):
st_d,st_i = self.doc_idx(ifnone(i.start,0))
en_d,en_i = self.doc_idx(ifnone(i.stop,self.totlen+1))
res = [self.chunks[st_d][st_i:(en_i if st_d==en_d else sys.maxsize)]]
for b in range(st_d+1,en_d): res.append(self.chunks[b])
if st_d!=en_d and en_d<len(self.chunks): res.append(self.chunks[en_d][:en_i])
return concat(*res)
def doc_idx(self, i):
if i<0: i=self.totlen+i # count from end
docidx = np.searchsorted(self.cumlens, i+1)-1
cl = self.cumlens[docidx]
return docidx,i-cl
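# Illustrative usage (editor's sketch, not part of the exported notebook): several lists presented as one indexable sequence,
#   c = Chunks([list(range(5)), list(range(5,10))])
#   c[3], c[7], c[2:8]            # -> 3, 7, [2, 3, 4, 5, 6, 7]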
#Cell
def one_param(m):
"First parameter in `m`"
return first(m.parameters())
#Cell
def item_find(x, idx=0):
"Recursively takes the `idx`-th element of `x`"
if is_listy(x): return item_find(x[idx])
if isinstance(x,dict):
key = list(x.keys())[idx] if isinstance(idx, int) else idx
return item_find(x[key])
return x
#Cell
def find_device(b):
"Recursively search the device of `b`."
return item_find(b).device
#Cell
def find_bs(b):
"Recursively search the batch size of `b`."
return item_find(b).shape[0]
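# Illustrative usage (editor's sketch, not part of the exported notebook):
#   b = (torch.zeros(16,3,32,32), torch.zeros(16))
#   find_bs(b), find_device(b)    # -> 16, device(type='cpu')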
#Cell
class Module(nn.Module, metaclass=PrePostInitMeta):
"Same as `nn.Module`, but no need for subclasses to call `super().__init__`"
def __pre_init__(self, *args, **kwargs): super().__init__()
def __init__(self): pass
#Cell
from torch.nn.parallel import DistributedDataParallel
def get_model(model):
"Return the model maybe wrapped inside `model`."
return model.module if isinstance(model, (DistributedDataParallel, nn.DataParallel)) else model
#Cell
def one_hot(x, c):
"One-hot encode `x` with `c` classes."
res = torch.zeros(c, dtype=torch.uint8)
if isinstance(x, Tensor) and x.numel()>0: res[x] = 1.
else: res[list(L(x, use_list=None))] = 1.
return res
#Cell
def one_hot_decode(x, vocab=None):
return L(vocab[i] if vocab else i for i,x_ in enumerate(x) if x_==1)
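# Illustrative usage (editor's sketch, not part of the exported notebook):
#   one_hot([0,2], 5)                              # -> tensor([1, 0, 1, 0, 0], dtype=torch.uint8)
#   one_hot_decode(one_hot([0,2], 5), 'abcde')     # -> ['a', 'c'] (as an `L`)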
#Cell
def params(m):
"Return all parameters of `m`"
return [p for p in m.parameters()]
#Cell
def trainable_params(m):
"Return all trainable parameters of `m`"
return [p for p in m.parameters() if p.requires_grad]
#Cell
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
#Cell
def bn_bias_params(m, with_bias=True):
"Return all bias and BatchNorm parameters"
if isinstance(m, bn_types): return L(m.parameters()) if with_bias else L(m.weight)
res = L(m.children()).map(bn_bias_params, with_bias=with_bias).concat()
#if with_bias and hasattr(m, 'bias'): res.append(m.bias)
return res
#Cell
def batch_to_samples(b, max_n=10):
"'Transposes' a batch to (at most `max_n`) samples"
if isinstance(b, Tensor): return retain_types(list(b[:max_n]), [b])
else:
res = L(b).map(partial(batch_to_samples,max_n=max_n))
return retain_types(res.zip(), [b])
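# Illustrative usage (editor's sketch, not part of the exported notebook): turn a collated batch back into individual samples,
#   batch_to_samples((torch.arange(4), torch.arange(4)*10), max_n=2)
#   # -> two (x,y) samples: (0,0) and (1,10)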
#Cell
@patch
def interp_1d(x:Tensor, xp, fp):
"Same as `np.interp`"
slopes = (fp[1:]-fp[:-1])/(xp[1:]-xp[:-1])
incx = fp[:-1] - (slopes*xp[:-1])
locs = (x[:,None]>=xp[None,:]).long().sum(1)-1
locs = locs.clamp(0,len(slopes)-1)
return slopes[locs]*x + incx[locs]
#Cell
def logit(x):
"Logit of `x`, clamped to avoid inf."
x = x.clamp(1e-7, 1-1e-7)
return -(1/x-1).log()
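# Illustrative usage (editor's sketch, not part of the exported notebook):
#   xp,fp = tensor(0.,1.,2.), tensor(0.,10.,20.)
#   tensor(0.5,1.5).interp_1d(xp, fp)    # -> tensor([ 5., 15.])
#   logit(tensor(0.5))                   # -> 0, the inverse of sigmoid at 0.5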
#Cell
def num_distrib():
"Return the number of processes in distributed training (if applicable)."
return int(os.environ.get('WORLD_SIZE', 0))
#Cell
def rank_distrib():
"Return the distributed rank of this process (if applicable)."
return int(os.environ.get('RANK', 0))
#Cell
def distrib_barrier():
"Place a synchronization barrier in distributed training so that ALL sub-processes in the pytorch process group must arrive here before proceeding."
if num_distrib() > 1: torch.distributed.barrier()
#Cell
# Saving arrays requires pytables - optional dependency
try: import tables
except: pass
#Cell
def _comp_filter(lib='lz4',lvl=3): return tables.Filters(complib=f'blosc:{lib}', complevel=lvl)
#Cell
@patch
def save_array(p:Path, o, complib='lz4', lvl=3):
"Save numpy array to a compressed `pytables` file, using compression level `lvl`"
if isinstance(o,Tensor): o = to_np(o)
with tables.open_file(p, mode='w', filters=_comp_filter(lib=complib,lvl=lvl)) as f: f.create_carray('/', 'data', obj=o)
#Cell
@patch
def load_array(p:Path):
"Save numpy array to a `pytables` file"
with tables.open_file(p, 'r') as f: return f.root.data.read()
#Cell
def make_cross_image(bw=True):
"Create a tensor containing a cross image, either `bw` (True) or color"
if bw:
im = torch.zeros(5,5)
im[2,:] = 1.
im[:,2] = 1.
else:
im = torch.zeros(3,5,5)
im[0,2,:] = 1.
im[1,:,2] = 1.
return im
#Cell
def show_image_batch(b, show=show_titled_image, items=9, cols=3, figsize=None, **kwargs):
"Display batch `b` in a grid of size `items` with `cols` width"
if items<cols: cols=items
rows = (items+cols-1) // cols
if figsize is None: figsize = (cols*3, rows*3)
fig,axs = plt.subplots(rows, cols, figsize=figsize)
for *o,ax in zip(*to_cpu(b), axs.flatten()): show(o, ax=ax, **kwargs)
#Cell
def requires_grad(m):
"Check if the first parameter of `m` requires grad or not"
ps = list(m.parameters())
return ps[0].requires_grad if len(ps)>0 else False
#Cell
def init_default(m, func=nn.init.kaiming_normal_):
"Initialize `m` weights with `func` and set `bias` to 0."
if func:
if hasattr(m, 'weight'): func(m.weight)
if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
return m
#Cell
def cond_init(m, func):
"Apply `init_default` to `m` unless it's a batchnorm module"
if (not isinstance(m, bn_types)) and requires_grad(m): init_default(m, func)
#Cell
def apply_leaf(m, f):
"Apply `f` to children of `m`."
c = m.children()
if isinstance(m, nn.Module): f(m)
for l in c: apply_leaf(l,f)
#Cell
def apply_init(m, func=nn.init.kaiming_normal_):
"Initialize all non-batchnorm layers of `m` with `func`."
apply_leaf(m, partial(cond_init, func=func))
#Cell
from multiprocessing import Process, Queue
#Cell
def set_num_threads(nt):
"Get numpy (and others) to use `nt` threads"
try: import mkl; mkl.set_num_threads(nt)
except: pass
torch.set_num_threads(1)
os.environ['IPC_ENABLE']='1'
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
os.environ[o] = str(nt)
#Cell
@delegates(concurrent.futures.ProcessPoolExecutor)
class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
def __init__(self, max_workers=None, on_exc=print, **kwargs):
self.not_parallel = max_workers==0
self.on_exc = on_exc
if self.not_parallel: max_workers=1
super().__init__(max_workers, **kwargs)
def map(self, f, items, *args, **kwargs):
g = partial(f, *args, **kwargs)
if self.not_parallel: return map(g, items)
try: return super().map(g, items)
except Exception as e: self.on_exc(e)
#Cell
def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=True, **kwargs):
"Applies `func` in parallel to `items`, using `n_workers`"
with ProcessPoolExecutor(n_workers) as ex:
r = ex.map(f,items, *args, **kwargs)
if progress:
if total is None: total = len(items)
r = progress_bar(r, total=total, leave=False)
return L(r)
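# Illustrative usage (editor's sketch, not part of the exported notebook; `f` must be picklable, i.e. defined at module top level):
#   def _sq(x): return x*x
#   parallel(_sq, range(8), n_workers=2)  # -> (#8) [0,1,4,9,16,25,36,49]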
#Cell
def run_procs(f, f_done, args):
"Call `f` for each item in `args` in parallel, yielding `f_done`"
processes = L(args).map(Process, args=arg0, target=f)
for o in processes: o.start()
try: yield from f_done()
except Exception as e: print(e)
finally: processes.map(Self.join())
#Cell
def parallel_gen(cls, items, n_workers=defaults.cpus, as_gen=False, **kwargs):
"Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel."
batches = np.array_split(items, n_workers)
idx = np.cumsum(0 + L(batches).map(len))
queue = Queue()
def f(batch, start_idx):
for i,b in enumerate(cls(**kwargs)(batch)): queue.put((start_idx+i,b))
def done(): return (queue.get() for _ in progress_bar(items, leave=False))
yield from run_procs(f, done, L(batches,idx).zip())
#Cell
def script_use_ctx(f):
"Decorator: create jit script and pass everything in `ctx.saved_variables to `f`, after `*args`"
sf = torch.jit.script(f)
def _f(ctx, *args, **kwargs): return sf(*args, *ctx.saved_variables, **kwargs)
return update_wrapper(_f,f)
#Cell
def script_save_ctx(static, *argidx):
"Decorator: create jit script and save args with indices `argidx` using `ctx.save_for_backward`"
def _dec(f):
sf = torch.jit.script(f)
def _f(ctx, *args, **kwargs):
if argidx:
save = [args[o] for o in argidx]
ctx.save_for_backward(*save)
if not argidx: args = [ctx]+list(args)
return sf(*args, **kwargs)
if static: _f = staticmethod(_f)
return update_wrapper(_f,f)
return _dec
#Cell
def script_fwd(*argidx):
"Decorator: create static jit script and save args with indices `argidx` using `ctx.save_for_backward`"
return script_save_ctx(True, *argidx)
#Cell
def script_bwd(f):
"Decorator: create static jit script and pass everything in `ctx.saved_variables to `f`, after `*args`"
return staticmethod(script_use_ctx(f))
#Cell
def grad_module(cls):
"Decorator: convert `cls` into an autograd function"
class _c(nn.Module):
def forward(self, *args, **kwargs): return cls.apply(*args, **kwargs)
return _c
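# Illustrative usage (editor's sketch, not part of the exported notebook): wrap a custom autograd Function into a module,
#   class _Double(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x): return 2*x
#       @staticmethod
#       def backward(ctx, grad_out): return 2*grad_out
#   DoubleModule = grad_module(_Double)
#   DoubleModule()(torch.ones(3, requires_grad=True))  # -> tensor([2., 2., 2.], grad_fn=...)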
#Comes from 13a_metrics.ipynb, cell
def flatten_check(inp, targ):
"Check that `out` and `targ` have the same number of elements and flatten them."
inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
test_eq(len(inp), len(targ))
return inp,targ
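# Illustrative usage (editor's sketch, not part of the exported notebook):
#   flatten_check(torch.zeros(2,3), torch.zeros(6))   # -> two flat tensors of length 6
#   flatten_check(torch.zeros(2,3), torch.zeros(5))   # raises, since 6 != 5 elements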