forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathParallelCommon.cpp
100 lines (79 loc) · 2.4 KB
/
ParallelCommon.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <ATen/Parallel.h>
#include <ATen/Config.h>
#include <ATen/PTThreadPool.h>
#include <ATen/Version.h>
#include <sstream>
#include <thread>
#ifdef TH_BLAS_MKL
#include <mkl.h>
#endif
#ifdef _OPENMP
#include <omp.h>
#endif
namespace at {
namespace {
// Looks up environment variable `var_name`.
// Returns a pointer to its value if it is set, otherwise `def_value`
// (nullptr unless the caller supplies a fallback such as "[not set]").
const char* get_env_var(
    const char* var_name, const char* def_value = nullptr) {
  if (const char* found = std::getenv(var_name)) {
    return found;
  }
  return def_value;
}
// Reads a thread count from environment variable `var_name`.
//
// Returns `def_value` when the variable is unset. When it is set, the value is
// parsed with std::stoi and must be > 0. If parsing throws or the check fails,
// a warning naming the variable and the error text is emitted via TORCH_WARN,
// and `def_value` is returned.
// Note: std::stoi parses a leading integer and ignores trailing characters,
// so e.g. "4,2" yields 4.
size_t get_env_num_threads(const char* var_name, size_t def_value = 0) {
  const char* const raw = std::getenv(var_name);
  if (raw == nullptr) {
    return def_value;
  }
  try {
    // The local is deliberately named `nthreads` and the check below is
    // spelled exactly as before: TORCH_CHECK's failure message may include the
    // condition's source text, and that message is forwarded into the warning.
    const int nthreads = std::stoi(raw);
    TORCH_CHECK(nthreads > 0);
    return static_cast<size_t>(nthreads);
  } catch (const std::exception& err) {
    std::ostringstream warning;
    warning << "Invalid " << var_name << " variable value, " << err.what();
    TORCH_WARN(warning.str());
  }
  return def_value;
}
} // namespace
// Builds a human-readable, multi-line report of ATen's threading setup:
//  - ATen intra-op / inter-op thread counts,
//  - OpenMP, MKL and MKL-DNN version strings (plus omp_get_max_threads() /
//    mkl_get_max_threads() when those libraries are compiled in),
//  - std::thread::hardware_concurrency(),
//  - the OMP_NUM_THREADS / MKL_NUM_THREADS environment variables,
//  - the parallel backend selected at build time.
std::string get_parallel_info() {
  std::ostringstream ss;
  // '\n' rather than std::endl: std::endl also flushes, which on an in-memory
  // ostringstream is pure overhead and does not change the returned text.
  ss << "ATen/Parallel:\n\tat::get_num_threads() : "
     << at::get_num_threads() << '\n';
  ss << "\tat::get_num_interop_threads() : "
     << at::get_num_interop_threads() << '\n';

  ss << at::get_openmp_version() << '\n';
#ifdef _OPENMP
  ss << "\tomp_get_max_threads() : " << omp_get_max_threads() << '\n';
#endif

  ss << at::get_mkl_version() << '\n';
#ifdef TH_BLAS_MKL
  ss << "\tmkl_get_max_threads() : " << mkl_get_max_threads() << '\n';
#endif

  ss << at::get_mkldnn_version() << '\n';
  ss << "std::thread::hardware_concurrency() : "
     << std::thread::hardware_concurrency() << '\n';

  ss << "Environment variables:" << '\n';
  // Each line reads "\t<NAME> : <value or [not set]>".
  for (const char* var : {"OMP_NUM_THREADS", "MKL_NUM_THREADS"}) {
    ss << '\t' << var << " : " << get_env_var(var, "[not set]") << '\n';
  }

  ss << "ATen parallel backend: ";
#if AT_PARALLEL_OPENMP
  ss << "OpenMP";
#elif AT_PARALLEL_NATIVE
  ss << "native thread pool";
#elif AT_PARALLEL_NATIVE_TBB
  ss << "native thread pool and TBB";
#endif
  ss << '\n';

#if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
  ss << "Experimental: single thread pool" << '\n';
#endif
  return ss.str();
}
// Returns the default size of the intra-op thread pool.
//
// Precedence:
//  1. MKL_NUM_THREADS, if it holds a positive integer;
//  2. otherwise OMP_NUM_THREADS, if it holds a positive integer;
//  3. otherwise TaskThreadPoolBase::defaultNumThreads().
// Invalid values are reported by get_env_num_threads and treated as unset.
int intraop_default_num_threads() {
  size_t nthreads = get_env_num_threads("OMP_NUM_THREADS", 0);
  // The OMP value is passed as the fallback, so a valid MKL_NUM_THREADS wins.
  nthreads = get_env_num_threads("MKL_NUM_THREADS", nthreads);
  if (nthreads == 0) {
    nthreads = TaskThreadPoolBase::defaultNumThreads();
  }
  // Make the size_t -> int narrowing explicit rather than implicit; env-var
  // values originate from an int (std::stoi), so they fit in int.
  return static_cast<int>(nthreads);
}
} // namespace at