Skip to content

Commit

Permalink
refactor split scale generation
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed Apr 6, 2023
1 parent 31967ea commit 582f57e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
1 change: 0 additions & 1 deletion draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,6 @@ def flatten_axes(
flat_axes
The names of the flattened axes from slowest to fastest varying.
"""

# Find the relevant axis positions
data_axes = list(dset.attrs["axis"])

Expand Down
36 changes: 22 additions & 14 deletions draco/analysis/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy.fft as fft
import pywt

from caput import config
from caput import config, mpiutil

from ..core import containers, task
from .delay import flatten_axes
Expand All @@ -24,6 +24,7 @@ class WaveletSpectrumEstimator(task.SingleTask):
average_axis = config.Property(proptype=str)
ndelay = config.Property(proptype=int, default=128)
wavelet = config.Property(proptype=str, default="morl")
chunks = config.Property(proptype=int, default=4)

def process(
self,
Expand Down Expand Up @@ -81,6 +82,7 @@ def process(
wspec.redistribute("baseline")
dspec.redistribute("baseline")
ws = wspec.spectrum[:].local_array
ww = wspec.weight[:].local_array
ds = dspec.spectrum[:].local_array

# Construct the
Expand All @@ -96,9 +98,12 @@ def process(
d = dset_view.local_array[ii]
w = weight_view.local_array[ii]

# Construct an averaged frequency mask
# Construct an averaged frequency mask and use it to set the output
# weights
Ni = w.mean(axis=0)
ww[ii] = Ni

# Construct a Wiener filter to in-fill the data
D = ds[ii]
Df = (F * D[np.newaxis, :]) @ F.T.conj()
iDf = la.inv(Df)
Expand All @@ -113,17 +118,20 @@ def process(
overwrite_b=True,
).T

with fft.set_workers(workers):
wd, s = pywt.cwt(
d_infill,
scales=wv_scales,
wavelet=self.wavelet,
axis=-1,
sampling_period=df,
method="fft",
)

ws[ii] = wd.var(axis=1)
_fast_tools._fast_var(wd, ws[ii])
# Doing the cwt and calculating the variance can eat a bunch of
# memory. Break it up into chunks to try and control this
for _, s, e in mpiutil.split_m(wv_scales.shape[0], self.chunks).T:
with fft.set_workers(workers):
wd, _ = pywt.cwt(
d_infill,
scales=wv_scales[s:e],
wavelet=self.wavelet,
axis=-1,
sampling_period=df,
method="fft",
)

# Calculate and set the variance
_fast_tools._fast_var(wd, ws[ii, s:e])

return wspec
13 changes: 12 additions & 1 deletion draco/core/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,13 +2356,24 @@ class WaveletSpectrum(ContainerBase):
"initialise": True,
"distributed": True,
"distributed_axis": "baseline",
}
},
"weight": {
"axes": ["baseline", "freq"],
"dtype": np.float64,
"initialise": True,
"distributed": True,
"distributed_axis": "baseline",
},
}

@property
def spectrum(self):
return self.datasets["spectrum"]

@property
def weight(self):
return self.datasets["weight"]


class Powerspectrum2D(ContainerBase):
"""Container for a 2D cartesian power spectrum.
Expand Down

0 comments on commit 582f57e

Please sign in to comment.