From 582f57ec39bc8d7d75d25558265d43d215587b0f Mon Sep 17 00:00:00 2001 From: Richard Shaw Date: Thu, 6 Apr 2023 16:34:57 -0700 Subject: [PATCH] refactor split scale generation --- draco/analysis/delay.py | 1 - draco/analysis/wavelet.py | 36 ++++++++++++++++++++++-------------- draco/core/containers.py | 13 ++++++++++++- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/draco/analysis/delay.py b/draco/analysis/delay.py index f8f4edca4..540aefd98 100644 --- a/draco/analysis/delay.py +++ b/draco/analysis/delay.py @@ -1248,7 +1248,6 @@ def flatten_axes( flat_axes The names of the flattened axes from slowest to fastest varying. """ - # Find the relevant axis positions data_axes = list(dset.attrs["axis"]) diff --git a/draco/analysis/wavelet.py b/draco/analysis/wavelet.py index f68d5d421..ae868b008 100644 --- a/draco/analysis/wavelet.py +++ b/draco/analysis/wavelet.py @@ -6,7 +6,7 @@ import scipy.fft as fft import pywt -from caput import config +from caput import config, mpiutil from ..core import containers, task from .delay import flatten_axes @@ -24,6 +24,7 @@ class WaveletSpectrumEstimator(task.SingleTask): average_axis = config.Property(proptype=str) ndelay = config.Property(proptype=int, default=128) wavelet = config.Property(proptype=str, default="morl") + chunks = config.Property(proptype=int, default=4) def process( self, @@ -81,6 +82,7 @@ def process( wspec.redistribute("baseline") dspec.redistribute("baseline") ws = wspec.spectrum[:].local_array + ww = wspec.weight[:].local_array ds = dspec.spectrum[:].local_array # Construct the @@ -96,9 +98,12 @@ def process( d = dset_view.local_array[ii] w = weight_view.local_array[ii] - # Construct an averaged frequency mask + # Construct an averaged frequency mask and use it to set the output + # weights Ni = w.mean(axis=0) + ww[ii] = Ni + # Construct a Wiener filter to in-fill the data D = ds[ii] Df = (F * D[np.newaxis, :]) @ F.T.conj() iDf = la.inv(Df) @@ -113,17 +118,20 @@ def process( overwrite_b=True, ).T - with fft.set_workers(workers): - wd, s = pywt.cwt( - d_infill, - scales=wv_scales, - wavelet=self.wavelet, - axis=-1, - sampling_period=df, - method="fft", - ) - - ws[ii] = wd.var(axis=1) - _fast_tools._fast_var(wd, ws[ii]) + # Doing the cwt and calculating the variance can eat a bunch of + # memory. Break it up into chunks to try and control this + for _, s, e in mpiutil.split_m(wv_scales.shape[0], self.chunks).T: + with fft.set_workers(workers): + wd, _ = pywt.cwt( + d_infill, + scales=wv_scales[s:e], + wavelet=self.wavelet, + axis=-1, + sampling_period=df, + method="fft", + ) + + # Calculate and set the variance + _fast_tools._fast_var(wd, ws[ii, s:e]) return wspec diff --git a/draco/core/containers.py b/draco/core/containers.py index 59df8a3ac..458566f09 100644 --- a/draco/core/containers.py +++ b/draco/core/containers.py @@ -2356,13 +2356,24 @@ class WaveletSpectrum(ContainerBase): "initialise": True, "distributed": True, "distributed_axis": "baseline", - } + }, + "weight": { + "axes": ["baseline", "freq"], + "dtype": np.float64, + "initialise": True, + "distributed": True, + "distributed_axis": "baseline", + }, } @property def spectrum(self): return self.datasets["spectrum"] + @property + def weight(self): + return self.datasets["weight"] + class Powerspectrum2D(ContainerBase): """Container for a 2D cartesian power spectrum.