forked from mila-iqia/covid_p2p_risk_prediction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
32 lines (25 loc) · 1.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import Sequence
import torch
def to_device(x, device):
    """Recursively move tensors and modules contained in ``x`` to ``device``.

    Containers are rebuilt with their original type: dicts via
    ``type(x)({...})``, lists/tuples via ``type(x)([...])``, and namedtuples
    by passing the converted fields positionally.

    Args:
        x: A tensor, a ``torch.nn.Module``, or an arbitrarily nested
            dict / list / tuple (including namedtuple) of those.
        device: Target passed straight through to ``.to(device)``.

    Returns:
        A structure mirroring ``x`` whose tensors/modules are the results of
        ``.to(device)``.

    Raises:
        NotImplementedError: If ``x`` or any nested element has a type that
            is not handled above.
    """
    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, dict):
        return type(x)({key: to_device(val, device) for key, val in x.items()})
    if isinstance(x, (list, tuple)):
        items = [to_device(item, device) for item in x]
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so unpack the converted items.
            return type(x)(*items)
        return type(x)(items)
    if isinstance(x, torch.nn.Module):
        return x.to(device)
    raise NotImplementedError(
        f"to_device does not support objects of type {type(x).__name__!r}"
    )
def momentum_accumulator(momentum):
    """Create an exponential-moving-average style update function.

    Args:
        momentum: Weight applied to the previously accumulated value; the
            new observation receives weight ``1 - momentum``.

    Returns:
        A function ``f(old, new)`` that returns
        ``momentum * old + (1 - momentum) * new``.
    """

    def _blend(previous, incoming):
        # Weighted mix of the running value and the fresh observation.
        kept = momentum * previous
        added = (1 - momentum) * incoming
        return kept + added

    return _blend
def thermometer_encoding(x: torch.Tensor, value_range: Sequence[float], size: int):
    """Thermometer-encode the scalar values held in ``x``.

    Each value is compared against ``size`` evenly spaced thresholds running
    from ``value_range[0]`` to ``value_range[1]`` (both inclusive, as produced
    by ``torch.linspace``); the output is 1 at every threshold the value
    strictly exceeds and 0 elsewhere.

    Args:
        x: Tensor whose last dimension has size 1, i.e. shape ``(..., 1)``.
        value_range: ``(low, high)`` endpoints of the thresholds.
        size: Number of thresholds, i.e. the length of each encoding.

    Returns:
        A float32 tensor of shape ``(..., size)`` containing only 0s and 1s.

    Raises:
        ValueError: If ``x`` is 0-dimensional or its last dimension is not 1.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let a mis-shaped x broadcast silently.
    if x.dim() == 0 or x.shape[-1] != 1:
        raise ValueError(
            f"x must have a trailing dimension of size 1, got shape {tuple(x.shape)}"
        )
    thresholds = torch.linspace(
        value_range[0], value_range[1], size, dtype=x.dtype, device=x.device
    )
    # x is (..., 1) and thresholds is (size,): broadcasting alone produces
    # (..., size), so no explicit expand to (1, ..., 1, size) is needed.
    return torch.gt(x, thresholds).float()