Skip to content

Commit

Permalink
[WIP] Add global var. to set # of OpenMP threads
Browse files Browse the repository at this point in the history
  • Loading branch information
LSchueler committed Dec 30, 2023
1 parent d2da47e commit 2e51a5b
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 16 deletions.
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ For the development version, you can do almost the same:
export GSTOOLS_BUILD_PARALLEL=1
pip install git+git://github.com/GeoStat-Framework/GSTools.git@main
The number of parallel threads can be set with the global variable `config.NUM_THREADS`.

**Using experimental GSTools-Core for even more speed**

You can install the optional dependency `GSTools-Core <https://github.com/GeoStat-Framework/GSTools-Core>`_,
Expand Down
2 changes: 2 additions & 0 deletions src/gstools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
.. currentmodule:: gstools.config
"""
NUM_THREADS = 1

# pylint: disable=W0611
try: # pragma: no cover
import gstools_core
Expand Down
10 changes: 8 additions & 2 deletions src/gstools/field/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def __call__(self, pos, add_nugget=True):
the random modes
"""
pos = np.asarray(pos, dtype=np.double)
summed_modes = summate(self._cov_sample, self._z_1, self._z_2, pos)
summed_modes = summate(
self._cov_sample, self._z_1, self._z_2, pos, config.NUM_THREADS
)
nugget = self.get_nugget(summed_modes.shape) if add_nugget else 0.0
return np.sqrt(self.model.var / self._mode_no) * summed_modes + nugget

Expand Down Expand Up @@ -489,7 +491,11 @@ def __call__(self, pos, add_nugget=True):
"""
pos = np.asarray(pos, dtype=np.double)
summed_modes = summate_incompr(
self._cov_sample, self._z_1, self._z_2, pos
self._cov_sample,
self._z_1,
self._z_2,
pos,
config.NUM_THREADS,
)
nugget = self.get_nugget(summed_modes.shape) if add_nugget else 0.0
e1 = self._create_unit_vector(summed_modes.shape)
Expand Down
8 changes: 5 additions & 3 deletions src/gstools/field/summator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def summate(
const double[:, :] cov_samples,
const double[:] z_1,
const double[:] z_2,
const double[:, :] pos
const double[:, :] pos,
const int num_threads=1,
):
cdef int i, j, d
cdef double phase
Expand All @@ -25,7 +26,7 @@ def summate(

cdef double[:] summed_modes = np.zeros(X_len, dtype=float)

for i in prange(X_len, nogil=True):
for i in prange(X_len, nogil=True, num_threads=num_threads):
for j in range(N):
phase = 0.
for d in range(dim):
Expand All @@ -49,7 +50,8 @@ def summate_incompr(
const double[:, :] cov_samples,
const double[:] z_1,
const double[:] z_2,
const double[:, :] pos
const double[:, :] pos,
const int num_threads=1,
):
cdef int i, j, d
cdef double phase
Expand Down
10 changes: 6 additions & 4 deletions src/gstools/krige/krigesum.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ cimport numpy as np
def calc_field_krige_and_variance(
const double[:, :] krig_mat,
const double[:, :] krig_vecs,
const double[:] cond
const double[:] cond,
const int num_threads=1,
):

cdef int mat_i = krig_mat.shape[0]
Expand All @@ -26,7 +27,7 @@ def calc_field_krige_and_variance(

# error = krig_vecs * krig_mat * krig_vecs
# field = cond * krig_mat * krig_vecs
for k in prange(res_i, nogil=True):
for k in prange(res_i, nogil=True, num_threads=num_threads):
for i in range(mat_i):
krig_fac = 0.0
for j in range(mat_i):
Expand All @@ -40,7 +41,8 @@ def calc_field_krige_and_variance(
def calc_field_krige(
const double[:, :] krig_mat,
const double[:, :] krig_vecs,
const double[:] cond
const double[:] cond,
const int num_threads=1,
):

cdef int mat_i = krig_mat.shape[0]
Expand All @@ -52,7 +54,7 @@ def calc_field_krige(
cdef int i, j, k

# field = cond * krig_mat * krig_vecs
for k in prange(res_i, nogil=True):
for k in prange(res_i, nogil=True, num_threads=num_threads):
for i in range(mat_i):
krig_fac = 0.0
for j in range(mat_i):
Expand Down
17 changes: 12 additions & 5 deletions src/gstools/variogram/estimator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def directional(
const double bandwidth=-1.0, # negative values to turn of bandwidth search
const bint separate_dirs=False, # whether the direction bands don't overlap
str estimator_type='m',
const int num_threads=1,
):
if pos.shape[1] != f.shape[1]:
raise ValueError(f'len(pos) = {pos.shape[1]} != len(f) = {f.shape[1])}')
Expand Down Expand Up @@ -208,7 +209,7 @@ def directional(
cdef int i, j, k, m, d
cdef double dist

for i in prange(i_max, nogil=True):
for i in prange(i_max, nogil=True, num_threads=num_threads):
for j in range(j_max):
for k in range(j+1, k_max):
dist = dist_euclid(dim, pos, j, k)
Expand Down Expand Up @@ -239,6 +240,7 @@ def unstructured(
const double[:, :] pos,
str estimator_type='m',
str distance_type='e',
const int num_threads=1,
):
cdef int dim = pos.shape[0]
cdef _dist_func distance
Expand Down Expand Up @@ -271,7 +273,7 @@ def unstructured(
cdef int i, j, k, m
cdef double dist

for i in prange(i_max, nogil=True):
for i in prange(i_max, nogil=True, num_threads=num_threads):
for j in range(j_max):
for k in range(j+1, k_max):
dist = distance(dim, pos, j, k)
Expand All @@ -287,7 +289,11 @@ def unstructured(
return np.asarray(variogram), np.asarray(counts)


def structured(const double[:, :] f, str estimator_type='m'):
def structured(
const double[:, :] f,
str estimator_type='m',
const int num_threads=1,
):
cdef _estimator_func estimator_func = choose_estimator_func(estimator_type)
cdef _normalization_func normalization_func = (
choose_estimator_normalization(estimator_type)
Expand All @@ -301,7 +307,7 @@ def structured(const double[:, :] f, str estimator_type='m'):
cdef long[:] counts = np.zeros(k_max, dtype=long)
cdef int i, j, k

with nogil, parallel():
with nogil, parallel(num_threads=num_threads):
for i in range(i_max):
for j in range(j_max):
for k in prange(1, k_max-i):
Expand All @@ -316,6 +322,7 @@ def ma_structured(
const double[:, :] f,
const bint[:, :] mask,
str estimator_type='m',
const int num_threads=1,
):
cdef _estimator_func estimator_func = choose_estimator_func(estimator_type)
cdef _normalization_func normalization_func = (
Expand All @@ -330,7 +337,7 @@ def ma_structured(
cdef long[:] counts = np.zeros(k_max, dtype=long)
cdef int i, j, k

with nogil, parallel():
with nogil, parallel(num_threads=num_threads):
for i in range(i_max):
for j in range(j_max):
for k in prange(1, k_max-i):
Expand Down
8 changes: 6 additions & 2 deletions src/gstools/variogram/variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def vario_estimate(
pos,
estimator_type=cython_estimator,
distance_type=distance_type,
num_threads=config.NUM_THREADS,
)
else:
estimates, counts = directional(
Expand All @@ -385,6 +386,7 @@ def vario_estimate(
bandwidth,
separate_dirs=_separate_dirs_test(direction, angles_tol),
estimator_type=cython_estimator,
num_threads=config.num_threads,
)
if dir_no == 1:
estimates, counts = estimates[0], counts[0]
Expand Down Expand Up @@ -485,8 +487,10 @@ def vario_estimate_axis(
cython_estimator = _set_estimator(estimator)

if masked:
return ma_structured(field, mask, cython_estimator)
return structured(field, cython_estimator)
return ma_structured(
field, mask, cython_estimator, num_threads=config.NUM_THREADS
)
return structured(field, cython_estimator, num_threads=config.NUM_THREADS)


# for backward compatibility
Expand Down

0 comments on commit 2e51a5b

Please sign in to comment.