Skip to content

Commit

Permalink
style: add ruff linting to entire repo
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Aug 2, 2024
1 parent 2945d90 commit 9846611
Show file tree
Hide file tree
Showing 30 changed files with 664 additions and 807 deletions.
17 changes: 10 additions & 7 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Set up Python 3.10
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.12"

- name: Install black
run: pip install black
- name: Install pip dependencies
run: pip install black ruff

- name: Run Ruff
run: ruff check .

- name: Check code with black
run: black --check .
Expand All @@ -39,13 +42,13 @@ jobs:
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev openmpi-bin libhdf5-serial-dev
- name: Set up Python 3.10
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.12"

- name: Install requirements
run: pip install .[doc]
run: pip install .[docs]

- name: Build sphinx docs
run: sphinx-build -W -b html doc/ doc/_build/html
124 changes: 48 additions & 76 deletions ch_pipeline/analysis/beam.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
"""Tasks for beam measurement processing"""
"""Tasks for beam measurement processing."""

import json
import yaml
from os import path, listdir
from os import listdir, path

import numpy as np
from scipy import constants

from caput import config, tod, mpiarray, mpiutil
from caput import config, mpiarray, mpiutil, tod
from caput.pipeline import PipelineRuntimeError
from caput.time import STELLAR_S, unix_to_datetime

from ch_ephem import coord, sources
from ch_ephem.observers import chime
from ch_util import tools, layout, holography
from chimedb import data_index as di
from ch_util import holography, tools
from chimedb.core import connect as connect_database


from draco.core import task, io
from draco.util import regrid
from draco.analysis.transform import Regridder
from draco.core import io, task
from draco.core.containers import ContainerBase, SiderealStream, TimeStream, TrackBeam
from draco.util import regrid
from draco.util.tools import invert_no_zero
from scipy import constants

from ..core.containers import TransitFitParams
from .calibration import TransitFit, GainFromTransitFit

from .calibration import TransitFit

SIDEREAL_DAY_SEC = STELLAR_S * 24 * 3600
SPEED_LIGHT = float(constants.c) / 1e6 # 10^6 m / s
Expand Down Expand Up @@ -97,13 +90,13 @@ def process(self, tstream):
Parameters
----------
tstream : TimeStream
timestream data container
Returns
-------
ts : TimeStream
Timestream with containing transits of specified length.
"""

# Redistribute if needed
tstream.redistribute("freq")

Expand All @@ -124,9 +117,7 @@ def process(self, tstream):
tr_time = self.observer.transit_times(self.src, tstream.time[0])
if len(tr_time) != 1:
raise ValueError(
"Didn't find exactly one transit time. Found {:d}.".format(
len(tr_time)
)
f"Didn't find exactly one transit time. Found {len(tr_time):d}."
)
self.cur_transit = tr_time[0]
self._transit_bounds()
Expand Down Expand Up @@ -161,19 +152,18 @@ def append(self, ts):
"""
for dname in ["evec", "eval", "erms"]:
if dname in ts.datasets.keys():
self.log.debug("Stripping dataset {}".format(dname))
self.log.debug(f"Stripping dataset {dname}")
del ts[dname]
self.tstreams.append(ts)

def _finalize_transit(self):
"""Concatenate grouped time streams for the currrent transit."""

# Find where transit starts and ends
if len(self.tstreams) == 0 or self.cur_transit is None:
self.log.info("Did not find any transits.")
return None
self.log.debug(
"Finalising transit for {}...".format(unix_to_datetime(self.cur_transit))
f"Finalising transit for {unix_to_datetime(self.cur_transit)}..."
)
all_t = np.concatenate([ts.time for ts in self.tstreams])
start_ind = int(np.argmin(np.abs(all_t - self.start_t)))
Expand All @@ -185,8 +175,7 @@ def _finalize_transit(self):
dt = self.tstreams[0].time[1] - self.tstreams[0].time[0]
if dt <= 0:
self.log.warning(
"Time steps are not positive definite: dt={:.3f}".format(dt)
+ " Skipping."
f"Time steps are not positive definite: dt={dt:.3f}" + " Skipping."
)
ts = None
if stop_ind - start_ind > int(self.min_span / 360.0 * SIDEREAL_DAY_SEC / dt):
Expand Down Expand Up @@ -221,8 +210,8 @@ def _transit_bounds(self):
"""Find the start and end times of this transit.
Compares the desired HA span to the start and end times of the observation
recorded in the database. Also gets the observation ID."""

recorded in the database. Also gets the observation ID.
"""
# subtract half a day from start time to ensure we don't get following day
self.start_t = self.cur_transit - self.ha_span / 360.0 / 2.0 * SIDEREAL_DAY_SEC
self.end_t = self.cur_transit + self.ha_span / 360.0 / 2.0 * SIDEREAL_DAY_SEC
Expand All @@ -235,9 +224,7 @@ def _transit_bounds(self):
]
if len(this_run) == 0:
self.log.warning(
"Could not find source transit in holography database for {}.".format(
unix_to_datetime(self.cur_transit)
)
f"Could not find source transit in holography database for {unix_to_datetime(self.cur_transit)}."
)
# skip this file
self.cur_transit = None
Expand Down Expand Up @@ -312,7 +299,6 @@ def process(self, data):
new_data : SiderealStream
The regridded data centered on the source RA.
"""

# Redistribute if needed
data.redistribute("freq")

Expand Down Expand Up @@ -556,7 +542,6 @@ def process(self, data, inputmap):
track : TrackBeam
The transit in a beam container.
"""

# redistribute if needed
data.redistribute("freq")

Expand Down Expand Up @@ -584,8 +569,8 @@ def process(self, data, inputmap):
):
msg = (
"Products do not separate into two groups with the length of the input map. "
"({:d}, {:d}) != {:d}"
).format(prod_groups[0].shape[0], prod_groups[1].shape[0], inputs.shape[0])
f"({prod_groups[0].shape[0]:d}, {prod_groups[1].shape[0]:d}) != {inputs.shape[0]:d}"
)
self.log.error(msg)
raise PipelineRuntimeError(msg)

Expand Down Expand Up @@ -714,11 +699,12 @@ def setup(self, tel):
Parameters
----------
tel : TransitTelescope
telescope object to use
"""
self.telescope = io.get_telescope(tel)

def process(self, beam, data):
"""Stack
"""Stack.
Parameters
----------
Expand Down Expand Up @@ -842,15 +828,14 @@ def _resolve_pol(pol1, pol2, pol_axis):

return ipol, ipol

if pol1 == pol2:
ipol1 = pol_axis.index(pol1)
ipol2 = pol_axis.index(pol2)
else:
if pol1 == pol2:
ipol1 = pol_axis.index(pol1)
ipol2 = pol_axis.index(pol2)
else:
ipol1 = pol_axis.index(pol2)
ipol2 = pol_axis.index(pol1)
ipol1 = pol_axis.index(pol2)
ipol2 = pol_axis.index(pol1)

return ipol1, ipol2
return ipol1, ipol2


class HolographyTransitFit(TransitFit):
Expand Down Expand Up @@ -984,7 +969,7 @@ def process(self, transit):


class ApplyHolographyGains(task.SingleTask):
"""Apply gains to a holography transit
"""Apply gains to a holography transit.
Attributes
----------
Expand All @@ -995,11 +980,11 @@ class ApplyHolographyGains(task.SingleTask):
overwrite = config.Property(proptype=bool, default=False)

def process(self, track_in, gain):
"""Apply gain
"""Apply gain.
Parameters
----------
track: draco.core.containers.TrackBeam
track_in: draco.core.containers.TrackBeam
Holography track to apply gains to. Will apply gains to
track['beam'], expecting axes to be freq, pol, input, ha
gain: np.array
Expand All @@ -1010,7 +995,6 @@ def process(self, track_in, gain):
track: draco.core.containers.TrackBeam
Holography track with gains applied.
"""

if self.overwrite:
track = track_in
else:
Expand Down Expand Up @@ -1054,7 +1038,6 @@ class TransitStacker(task.SingleTask):

def setup(self):
"""Initialise internal variables."""

self.stack = None
self.variance = None
self.pseudo_variance = None
Expand All @@ -1068,8 +1051,7 @@ def process(self, transit):
transit: draco.core.containers.TrackBeam
A holography transit.
"""

self.log.info("Weight is %s" % self.weight)
self.log.info(f"Weight is {self.weight}")

if self.stack is None:
self.log.info("Initializing transit stack.")
Expand All @@ -1081,7 +1063,7 @@ def process(self, transit):
self.stack.add_dataset("nsample")
self.stack.redistribute("freq")

self.log.info("Adding %s to stack." % transit.attrs["tag"])
self.log.info(f"Adding {transit.attrs['tag']} to stack.")

# Copy over relevant attributes
self.stack.attrs["filename"] = [transit.attrs["tag"]]
Expand Down Expand Up @@ -1112,14 +1094,12 @@ def process(self, transit):
else:
if list(transit.beam.shape) != list(self.stack.beam.shape):
self.log.error(
"Transit has different shape than stack: {}, {}".format(
transit.beam.shape, self.stack.beam.shape
)
+ " Skipping."
f"Transit has different shape than stack: {transit.beam.shape}, {self.stack.beam.shape}. "
"Skipping."
)
return None
return

self.log.info("Adding %s to stack." % transit.attrs["tag"])
self.log.info(f"Adding {transit.attrs['tag']} to stack.")

self.stack.attrs["filename"].append(transit.attrs["tag"])
self.stack.attrs["observation_id"].append(transit.attrs["observation_id"])
Expand All @@ -1141,7 +1121,7 @@ def process(self, transit):
self.pseudo_variance += coeff * transit.beam[:] ** 2
self.norm += coeff

return None
return

def process_finish(self):
"""Normalise the stack and return the result.
Expand Down Expand Up @@ -1183,8 +1163,9 @@ def process_finish(self):


class FilterHolographyProcessed(task.MPILoggedTask):
"""Filter holography transit DataIntervals produced by `io.QueryDatabase`
to exclude those already processed.
"""Filter holography transit DataIntervals produced by `io.QueryDatabase`.
Excludes DataIntervals which are already processed.
Attributes
----------
Expand All @@ -1201,13 +1182,10 @@ class FilterHolographyProcessed(task.MPILoggedTask):

def setup(self):
"""Get a list of existing processed files."""

# Find processed transit files
self.proc_transits = []
for processed_dir in self.processed_dir:
self.log.debug(
"Looking for processed transits in {}...".format(processed_dir)
)
self.log.debug(f"Looking for processed transits in {processed_dir}...")
# Expand path
processed_dir = path.expanduser(processed_dir)
processed_dir = path.expandvars(processed_dir)
Expand All @@ -1225,7 +1203,7 @@ def setup(self):
obs_id = fh.attrs.get("observation_id", None)
if obs_id is not None:
self.proc_transits.append(obs_id)
self.log.debug("Found {:d} processed transits.".format(len(self.proc_transits)))
self.log.debug(f"Found {len(self.proc_transits):d} processed transits.")

# Query database for observations of this source
hol_obs = None
Expand All @@ -1247,8 +1225,7 @@ def next(self, intervals):
files: list of str
List of files to be processed.
"""

self.log.info("Starting next for task %s" % self.__class__.__name__)
self.log.info(f"Starting next for task {self.__class__.__name__}")

self.comm.Barrier()

Expand All @@ -1266,20 +1243,16 @@ def next(self, intervals):

if len(this_obs) == 0:
self.log.warning(
"Could not find source transit in holography database for {}.".format(
unix_to_datetime(start)
)
f"Could not find source transit in holography database for {unix_to_datetime(start)}."
)
elif this_obs[0].id in self.proc_transits:
self.log.warning(
"Already processed transit for {}. Skipping.".format(
unix_to_datetime(start)
)
f"Already processed transit for {unix_to_datetime(start)}. Skipping."
)
else:
files += fi[0]

self.log.info("Leaving next for task %s" % self.__class__.__name__)
self.log.info(f"Leaving next for task {self.__class__.__name__}")
return files


Expand Down Expand Up @@ -1312,8 +1285,7 @@ def unwrap_lha(lsa, src_ra):


def get_holography_obs(src):
"""Query database for list of all holography observations for the given
source.
"""Query database for list of all holography observations for the given source.
Parameters
----------
Expand All @@ -1327,7 +1299,7 @@ def get_holography_obs(src):
"""
connect_database()
db_src = holography.HolographySource.get(holography.HolographySource.name == src)
db_obs = holography.HolographyObservation.select().where(

return holography.HolographyObservation.select().where(
holography.HolographyObservation.source == db_src
)
return db_obs
Loading

0 comments on commit 9846611

Please sign in to comment.