-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
105 lines (85 loc) · 2.67 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import numpy as np
import time
import datetime
import pytz
import torch
def init_random_state(seed=0):
    """Seed all random number generators used by the project for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch on the
    CPU and on all CUDA devices. It also configures cuDNN to favour
    deterministic kernels.

    Parameters
    ----------
    seed : int
        Seed value applied to every generator.
    """
    # ``torch`` is already imported at module level, so only ``random`` is
    # imported here.
    import random
    # If DGL is used, also seed it: dgl.seed(seed); dgl.random.seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and choose convolution algorithms at
    # runtime, which can differ between runs and break reproducibility. It
    # must be off for ``deterministic`` to give repeatable results.
    torch.backends.cudnn.benchmark = False
def mkdir_p(path, log=True):
    """Create a directory for the specified path, including missing parents.

    If ``path`` already exists (as anything), the function returns immediately
    without printing.

    Parameters
    ----------
    path : str
        Path name
    log : bool
        Whether to print result for directory creation

    Raises
    ------
    OSError
        If ``os.makedirs`` fails for any reason other than the directory having
        appeared concurrently.
    """
    import errno
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
    except OSError as exc:
        # Another process may create the directory between the exists() check
        # and makedirs(). That race is harmless and must be tolerated whether
        # or not logging is enabled; only the message depends on ``log``.
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            if log:
                print('Directory {} already exists.'.format(path))
            return
        raise
    if log:
        print('Created directory {}'.format(path))
def get_dir_of_file(f_name):
    """Return the directory component of ``f_name`` with a trailing ``'/'``.

    A name with no directory component (e.g. ``'x.txt'``) maps to ``'./'``,
    the current directory.
    """
    # os.path.dirname('x.txt') is ''. Appending '/' directly would yield '/'
    # (the filesystem root) and send callers' paths to the wrong place.
    return (os.path.dirname(f_name) or '.') + '/'
def init_path(dir_or_file):
    """Ensure the directory part of ``dir_or_file`` exists, then return the argument.

    The directory is derived with ``get_dir_of_file`` and created with
    ``mkdir_p`` only when it is missing. The input is returned unchanged so the
    call can be used inline.
    """
    parent_dir = get_dir_of_file(dir_or_file)
    if os.path.exists(parent_dir):
        return dir_or_file
    mkdir_p(parent_dir)
    return dir_or_file
def get_available_devices():
    r"""Discover the compute devices visible to PyTorch.

    When CUDA is available, the first GPU is also made the current CUDA device
    via ``torch.cuda.set_device``.

    Returns:
        device (torch.device): ``cuda:<first id>`` when CUDA is available,
            otherwise ``cpu``.
        gpu_ids (list): IDs ``0 .. device_count-1`` of the visible GPUs
            (empty on CPU-only machines).
    """
    if not torch.cuda.is_available():
        return torch.device('cpu'), []
    gpu_ids = list(range(torch.cuda.device_count()))
    main_device = torch.device(f'cuda:{gpu_ids[0]}')
    torch.cuda.set_device(main_device)
    return main_device, gpu_ids
# * ============================= Time Related =============================
def time2str(t):
    """Format a duration given in seconds as a compact string.

    The largest unit whose size is strictly exceeded by ``t`` is used (days,
    hours, minutes), falling back to seconds. Values are shown with two
    decimals, e.g. ``90`` -> ``'1.50min'``. Exact unit boundaries stay in the
    smaller unit, e.g. ``60`` -> ``'60.00s'``.
    """
    for unit_seconds, template in ((86400, '{:.2f}day'),
                                   (3600, '{:.2f}h'),
                                   (60, '{:.2f}min')):
        if t > unit_seconds:
            return template.format(t / unit_seconds)
    return '{:.2f}s'.format(t)
def get_cur_time(timezone='Asia/Shanghai', t_format='%m-%d %H:%M:%S'):
    """Return the current wall-clock time in ``timezone`` formatted with ``t_format``.

    The timestamp is truncated to whole seconds before conversion. ``timezone``
    is resolved through ``pytz.timezone``.
    """
    now_seconds = int(time.time())
    tz = pytz.timezone(timezone)
    localized = datetime.datetime.fromtimestamp(now_seconds, tz)
    return localized.strftime(t_format)
def time_logger(func):
    """Decorator that prints when ``func`` starts and finishes, and how long it ran.

    Timestamps come from ``get_cur_time`` and the elapsed time is formatted by
    ``time2str``. The wrapped function's return value is passed through
    unchanged. ``functools.wraps`` copies ``__name__``, ``__doc__`` and sets
    ``__wrapped__``, so decorated functions keep their identity for
    introspection and for other decorators.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kw):
        start_time = time.time()
        print(f'Start running {func.__name__} at {get_cur_time()}')
        ret = func(*args, **kw)
        print(
            f'Finished running {func.__name__} at {get_cur_time()}, running time = {time2str(time.time() - start_time)}.')
        return ret
    return wrapper