Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel plotting in surfchi2 + update python requires #7

Merged
merged 6 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.10", "3.11"]

steps:
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version}}

- name: Checkout source
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
fetch-depth: 1

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
long_description_content_type="text/markdown",
url="https://github.com/ratt-ru/surfvis",
packages=find_packages(),
python_requires='>=3.7',
python_requires='>=3.10',
install_requires=requirements,
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
3 changes: 2 additions & 1 deletion surfvis/flagchi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,5 @@ def main():
columns=[options.fcol, 'FLAG_ROW'],
rechunk=True)

dask.compute(writes)
with ProgressBar():
dask.compute(writes)
265 changes: 111 additions & 154 deletions surfvis/surfchi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import dask
import dask.array as da
from dask.diagnostics import ProgressBar
from surfvis.utils import surfchisq
from surfvis.utils import surfchisq, surfchisq_plot
from daskms import xds_from_storage_ms as xds_from_ms
from daskms import xds_from_storage_table as xds_from_table
from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr
# might make for cooler histograms but doesn't work out of the box
from astropy.visualization import hist
from pathlib import Path
import concurrent.futures as cf


# COMMAND LINE OPTIONS
Expand Down Expand Up @@ -50,6 +52,10 @@ def create_parser():
def main():
(options,args) = create_parser().parse_args()

print('Input Options:')
for key, value in vars(options).items():
print(' %25s = %s' % (key, value))

if options.dataout == '':
options.dataout = os.getcwd() + '/chi2'

Expand Down Expand Up @@ -111,39 +117,37 @@ def main():

ridx = np.zeros(len(row_chunks))
ridx[1:] = np.cumsum(row_chunks)[0:-1]
rbin_idx.append(da.from_array(ridx.astype(int), chunks=1))
rbin_counts.append(da.from_array(row_chunks, chunks=1))
rbin_idx.append(ridx.astype(int))
rbin_counts.append(row_chunks)

ntime = ut.size
tidx = np.arange(0, ntime, utpc)
tbin_idx.append(da.from_array(tidx.astype(int), chunks=1))
tbin_idx.append(tidx.astype(int))
tidx2 = np.append(tidx, ntime)
tcounts = tidx2[1:] - tidx2[0:-1]
tbin_counts.append(da.from_array(tcounts, chunks=1))
tbin_counts.append(tcounts)

t0 = ut[tidx]
t0s.append(da.from_array(t0, chunks=1))
t0s.append(t0)
tf = ut[tidx + tcounts -1]
tfs.append(da.from_array(tf, chunks=1))
tfs.append(tf)

fidx = np.arange(0, nchan, options.nfreqs)
fbin_idx.append(da.from_array(fidx, chunks=1))
fbin_idx.append(fidx)
fidx2 = np.append(fidx, nchan)
fcounts = fidx2[1:] - fidx2[0:-1]
fbin_counts.append(da.from_array(fcounts, chunks=1))
fbin_counts.append(fcounts)

schema = {}
schema[options.rcol] = {'dims': ('chan', 'corr')}
schema[options.wcol] = {'dims': ('chan', 'corr')}
schema[options.fcol] = {'dims': ('chan', 'corr')}

xds = xds_from_ms(msname,
columns=[options.rcol, options.wcol, options.fcol,
'ANTENNA1', 'ANTENNA2', 'TIME'],
chunks=chunks,
group_cols=['FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'],
table_schema=schema)

columns=[options.rcol, options.wcol, options.fcol,'ANTENNA1', 'ANTENNA2', 'TIME'],
chunks=chunks,
group_cols=['FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'],
table_schema=schema)
if options.use_corrs is None:
print('Using only diagonal correlations')
if len(xds[0].corr) > 1:
Expand All @@ -155,137 +159,89 @@ def main():
print(f"Using correlations {use_corrs}")
ncorr = len(use_corrs)

out_ds = []
idts = []
for i, ds in enumerate(xds):
ds = ds.sel(corr=use_corrs)

resid = ds.get(options.rcol).data
if options.wcol == 'SIGMA_SPECTRUM':
weight = 1.0/ds.get(options.wcol).data**2
else:
weight = ds.get(options.wcol).data
flag = ds.get(options.fcol).data
ant1 = ds.ANTENNA1.data
ant2 = ds.ANTENNA2.data

# ncorr = resid.shape[0]

# time = ds.TIME.values
# utime = np.unique(time)

# spw = xds_from_table(msname + '::SPECTRAL_WINDOW')
# freq = spw[0].CHAN_FREQ.values

field = ds.FIELD_ID
ddid = ds.DATA_DESC_ID
scan = ds.SCAN_NUMBER

tmp = surfchisq(resid, weight, flag, ant1, ant2,
rbin_idx[i], rbin_counts[i],
fbin_idx[i], fbin_counts[i])

d = xr.Dataset(
data_vars={'data': (('time', 'freq', 'corr', 'p', 'q', '2'), tmp),
'fbin_idx': (('freq'), fbin_idx[i]),
'fbin_counts': (('freq'), fbin_counts[i]),
'tbin_idx': (('time'), tbin_idx[i]),
'tbin_counts': (('time'), tbin_counts[i])},
attrs = {'FIELD_ID': ds.FIELD_ID,
'DATA_DESC_ID': ds.DATA_DESC_ID,
'SCAN_NUMBER': ds.SCAN_NUMBER},
# coords={'time': (('time'), utime),
# 'freq': (('freq'), freq),
# 'corr': (('corr'), np.arange(ncorr))}
)

idt = f'::F{ds.FIELD_ID}_D{ds.DATA_DESC_ID}_S{ds.SCAN_NUMBER}'
out_ds.append(xds_to_zarr(d, options.dataout + idt))
idts.append(idt)


dask.compute(out_ds)

# primitive plotting
if options.imagesout is not None:
foldername = options.imagesout.rstrip('/')
if not os.path.isdir(foldername):
os.system('mkdir '+ foldername)

for idt in idts:
xds = xds_from_zarr(options.dataout + idt)
for ds in xds:
field = ds.FIELD_ID
if not os.path.isdir(foldername + f'/field{field}'):
os.system('mkdir '+ foldername + f'/field{field}')

spw = ds.DATA_DESC_ID
if not os.path.isdir(foldername + f'/field{field}' + f'/spw{spw}'):
os.system('mkdir '+ foldername + f'/field{field}' + f'/spw{spw}')

scan = ds.SCAN_NUMBER
if not os.path.isdir(foldername + f'/field{field}' + f'/spw{spw}' + f'/scan{scan}'):
os.system('mkdir '+ foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}')

tmp = ds.data.values
tbin_idx = ds.tbin_idx.values
tbin_counts = ds.tbin_counts.values
fbin_idx = ds.fbin_idx.values
fbin_counts = ds.fbin_counts.values

ntime, nfreq, ncorr, _, _, _ = tmp.shape

basename = foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}/'
if len(os.listdir(basename)):
print(f"Removing contents of {basename} folder")
os.system(f'rm {basename}*.png')
for t in range(ntime):
for f in range(nfreq):
for c in range(ncorr):
chi2 = tmp[t, f, c, :, :, 0]
N = tmp[t, f, c, :, :, 1]
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
t0 = tbin_idx[t]
tf = tbin_idx[t] + tbin_counts[t]
chan0 = fbin_idx[f]
chanf = fbin_idx[f] + fbin_counts[f]
makeplot(chi2_dof, basename + f't{t}_f{f}_c{c}.png',
f't {t0}-{tf}, chan {chan0}-{chanf}, corr {c}')

# reduce over corr
chi2 = np.nansum(tmp[t, f, (0, -1), :, :, 0], axis=0)
N = np.nansum(tmp[t, f, (0, -1), :, :, 1], axis=0)
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
t0 = tbin_idx[t]
tf = tbin_idx[t] + tbin_counts[t]
chan0 = fbin_idx[f]
chanf = fbin_idx[f] + fbin_counts[f]
makeplot(chi2_dof, basename + f't{t}_f{f}.png',
f't {t0}-{tf}, chan {chan0}-{chanf}')

# reduce over freq
chi2 = np.nansum(tmp[t, :, (0, -1), :, :, 0], axis=(0,1))
N = np.nansum(tmp[t, :, (0, -1), :, :, 1], axis=(0,1))
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
t0 = tbin_idx[t]
tf = tbin_idx[t] + tbin_counts[t]
makeplot(chi2_dof, basename + f't{t}.png',
f't {t0}-{tf}')

# now the entire scan
chi2 = np.nansum(tmp[:, :, (0, -1), :, :, 0], axis=(0,1,2))
N = np.nansum(tmp[:, :, (0, -1), :, :, 1], axis=(0,1,2))
chi2_dof = np.zeros_like(chi2)
chi2_dof[N>0] = chi2[N>0]/N[N>0]
chi2_dof[N==0] = np.nan
makeplot(chi2_dof, basename + f'scan.png',
f'scan {scan}.png')
chi2s = {}
counts = {}
futures = []
foldername = options.imagesout.rstrip('/')
with cf.ProcessPoolExecutor(max_workers=options.nthreads) as executor:
for i, ds in enumerate(xds):
field = ds.FIELD_ID
spw = ds.DATA_DESC_ID
scan = ds.SCAN_NUMBER

basename = foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}/'
odir = Path(basename).resolve()
odir.mkdir(parents=True, exist_ok=True)

ntime = tbin_idx[i].size
nfreq = fbin_idx[i].size
ncorr = len(use_corrs)
for t in range(ntime):
for f in range(nfreq):
for c in range(ncorr):
t0 = tbin_idx[i][t]
tf = t0 + tbin_counts[i][t]
chan0 = fbin_idx[i][f]
chanf = chan0 + fbin_counts[i][f]
row0 = rbin_idx[i][t]
rowf = rbin_idx[i][t] + rbin_counts[i][t]
Inu = slice(chan0, chanf)
Irow = slice(row0, rowf)
dso = ds[{'row': Irow, 'chan': Inu}]
# import ipdb; ipdb.set_trace()
dso = dso.sel(corr=use_corrs)
resid = dso.get(options.rcol).data
if options.wcol == 'SIGMA_SPECTRUM':
weight = 1.0/dso.get(options.wcol).data**2
else:
weight = dso.get(options.wcol).data
flag = dso.get(options.fcol).data
ant1 = dso.ANTENNA1.data
ant2 = dso.ANTENNA2.data
t0 = tbin_idx[i][t]
tf = t0 + tbin_counts[i][t]
chan0 = fbin_idx[i][f]
chanf = chan0 + fbin_counts[i][f]
fut = executor.submit(surfchisq_plot, resid, weight, flag, ant1, ant2,
field, spw, scan,
basename + f't{t}_f{f}_c{c}.png',
f't {t0}-{tf}, chan {chan0}-{chanf}, corr {c}')
futures.append(fut)

# to reduce over time, freq and corr at the end
nant = np.maximum(ant1.compute().max(), ant2.compute().max()) + 1
chi2s[f'field{field}_spw{spw}_scan{scan}'] = np.zeros((nant, nant), dtype=float)
counts[f'field{field}_spw{spw}_scan{scan}'] = np.zeros((nant, nant), dtype=float)
print(f"Submitted field{field}_spw{spw}_scan{scan}")

# reduce per scan
num_completed = 0
num_futures = len(futures)
for fut in cf.as_completed(futures):
num_completed += 1
print(f"\rProcessing: {num_completed}/{num_futures}", end='', flush=True)
try:
field, spw, scan, chi2, count = fut.result()
chi2s[f'field{field}_spw{spw}_scan{scan}'] += chi2
counts[f'field{field}_spw{spw}_scan{scan}'] += count
except Exception as e:
raise e

# LB - is it worth doing this in parallel?
print("Plotting per scan")
for key, val in chi2s.items():
field, spw, scan = key.split('_')
field = field.strip('field')
spw = spw.strip('spw')
scan = scan.strip('scan')
count = counts[key]
chi2_dof = np.zeros_like(val)
chi2_dof[count>0] = val[count>0]/count[count>0]
chi2_dof[count<=0] = np.nan

basename = foldername + f'/field{field}' + f'/spw{spw}'+ f'/scan{scan}/'
makeplot(chi2_dof, basename + f'combined.png',
f'scan {scan}.png')

def makeplot(data, name, subt):
nant, _ = data.shape
Expand All @@ -305,14 +261,15 @@ def makeplot(data, name, subt):

rax = divider.append_axes("right", size="50%", pad=0.025)
x = data[~ np.isnan(data)]
hist(x, bins='scott', ax=rax, histtype='stepfilled',
alpha=0.5, density=False)
rax.set_yticks([])
rax.tick_params(axis='y', which='both',
bottom=False, top=False,
labelbottom=False)
rax.tick_params(axis='x', which='both',
length=1, width=1, labelsize=8)
if x.any():
hist(x, bins='scott', ax=rax, histtype='stepfilled',
alpha=0.5, density=False)
rax.set_yticks([])
rax.tick_params(axis='y', which='both',
bottom=False, top=False,
labelbottom=False)
rax.tick_params(axis='x', which='both',
length=1, width=1, labelsize=8)

fig.suptitle(subt, fontsize=20)
plt.savefig(name, dpi=250)
Expand Down
Loading
Loading