From 8c7e3d019ac20df6a33ba0075af52a5c31214743 Mon Sep 17 00:00:00 2001 From: ketiltrout Date: Fri, 2 Aug 2024 11:15:14 -0700 Subject: [PATCH] ruffage --- .github/workflows/main.yml | 19 +- ch_util/_db_tables.py | 200 ++++++------ ch_util/andata.py | 216 +++++++------ ch_util/cal_utils.py | 172 ++++++----- ch_util/chan_monitor.py | 68 ++--- ch_util/connectdb.py | 32 +- ch_util/data_index.py | 211 ++++++------- ch_util/data_quality.py | 65 ++-- ch_util/ephemeris.py | 71 +++-- ch_util/finder.py | 115 +++---- ch_util/fluxcat.py | 100 +++--- ch_util/hfbcat.py | 9 +- ch_util/holography.py | 116 +++---- ch_util/layout.py | 284 +++++++++--------- ch_util/ni_utils.py | 54 ++-- ch_util/plot.py | 63 ++-- ch_util/rfi.py | 14 +- ch_util/timing.py | 271 +++++++++-------- ch_util/tools.py | 208 ++++++------- doc/conf.py | 1 - pyproject.toml | 6 + scripts/scan2txt.py | 100 +++--- .../generate_archive_test_data_2_X.py | 2 +- .../generate_archive_test_data_3_X.py | 2 +- scripts/update_psrcat.py | 16 +- tests/test_andata.py | 4 +- tests/test_andata_archive2.py | 7 +- tests/test_andata_archive3.py | 6 +- tests/test_andata_dist.py | 4 +- 29 files changed, 1198 insertions(+), 1238 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8bd211fb..7c90e025 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -9,8 +9,7 @@ on: jobs: - lint-code: - + black-check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -26,6 +25,22 @@ jobs: - name: Check code with black run: black --check . + ruff-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install ruff + run: pip install ruff + + - name: Check code with ruff + run: ruff check . + run-tests: strategy: diff --git a/ch_util/_db_tables.py b/ch_util/_db_tables.py index 0c977e61..2de37a48 100644 --- a/ch_util/_db_tables.py +++ b/ch_util/_db_tables.py @@ -10,9 +10,9 @@ import chimedb.core import chimedb.data_index from chimedb.core import AlreadyExistsError as AlreadyExists +from chimedb.core.orm import EnumField, base_model, name_table import peewee as pw -import numpy as np # Logging # ======= @@ -88,8 +88,6 @@ class ClosestDraw(chimedb.core.CHIMEdbError): # Helper classes for the peewee ORM # ================================= -from chimedb.core.orm import JSONDictField, EnumField, base_model, name_table - class event_table(base_model): """Baseclass for all models which are linked to the event class.""" @@ -128,7 +126,7 @@ def event( ) if type: try: - dummy = iter(type) + iter(type) ret = ret.where(event.type << type) except TypeError: ret = ret.where(event.type == type) @@ -184,7 +182,7 @@ def set_user(u): """ global _user - _user = dict() + _user = {} # Find the user. if isinstance(u, int): @@ -194,8 +192,8 @@ def set_user(u): else: q = chimedb.core.proxy.execute_sql( "SELECT user_id FROM chimewiki.user " - "WHERE user_name = '%s' OR " - "user_real_name = '%s';" % (u, u) + f"WHERE user_name = '{u}' OR " + f"user_real_name = '{u}';" ) r = q.fetchone() if not r: @@ -219,17 +217,15 @@ def _check_user(perm): ) if perm not in _user["perm"]: try: - p = ( + perm = ( user_permission_type.select() .where(user_permission_type.name == perm) .get() ) - raise NoPermission("You do not have the permissions to %s." 
% p.long_name) + raise NoPermission(f"You do not have the permissions to {perm.long_name}.") except pw.DoesNotExist: raise RuntimeError( - "Internal error: _check_user called with unknown permission: {}".format( - perm - ) + f"Internal error: _check_user called with unknown permission: {perm}" ) @@ -317,7 +313,8 @@ def start(self, time=datetime.datetime.now(), notes=None): The following starts and ends a new global flag. >>> cat = layout.global_flag_category.get(name = "pass") - >>> flag = layout.global_flag(category = cat, severity = "comment", name = "run_pass12_a").start(time = datetime.datetime(2015, 4, 1, 12)) + >>> flag = layout.global_flag(category = cat, severity = "comment", \ + ... name = "run_pass12_a").start(time = datetime.datetime(2015, 4, 1, 12)) >>> flag.end(time = datetime.datetime(2015, 4, 5, 15, 30)) Parameters @@ -492,7 +489,7 @@ class component(event_table): type = pw.ForeignKeyField(component_type, backref="component") type_rev = pw.ForeignKeyField(component_type_rev, backref="component", null=True) - class Meta(object): + class Meta: indexes = (("sn"), True) def __hash__(self): @@ -534,8 +531,8 @@ def get_connexion( ) if comp: return c.where((connexion.comp1 == comp) | (connexion.comp2 == comp)).get() - else: - return c + + return c def get_history( self, time=datetime.datetime.now(), when=EVENT_AT, order=ORDER_ASC, active=True @@ -591,7 +588,8 @@ def add(self, time=datetime.datetime.now(), notes=None, force=True): >>> lna_type = layout.component_type.get(name = "LNA") >>> lna_rev = lna_type.rev.where(layout.component_type_rev.name == "B").get() - >>> comp = layout.component(sn = "LNA0000A", type = lna_type, rev = lna_type.rev).add() + >>> comp = layout.component(sn = "LNA0000A", type = lna_type, \ + ... rev = lna_type.rev).add() Parameters ---------- @@ -706,8 +704,8 @@ def get_property(self, type=None, time=datetime.datetime.now()): ) if type: return p.where(property.type == type).get() - else: - return p.get() + + return p.get() def set_property(self, type, value, time=datetime.datetime.now(), notes=None): """Set a property for this component. @@ -809,7 +807,7 @@ class connexion(event_table): component, column_name="comp_sn2", field="sn", backref="conn2" ) - class Meta(object): + class Meta: indexes = (("component_sn1", "component_sn2"), True) @classmethod @@ -837,7 +835,7 @@ def from_pair(cls, comp1, comp2, allow_new=True): try: pair.append(component.get(sn=comp)) except pw.DoesNotExist: - raise DoesNotExist("Component %s does not exist." % comp) + raise DoesNotExist(f"Component {comp} does not exist.") q = cls.select().where( ((cls.comp1 == pair[0]) & (cls.comp2 == pair[1])) | ((cls.comp1 == pair[1]) & (cls.comp2 == pair[0])) @@ -845,7 +843,7 @@ def from_pair(cls, comp1, comp2, allow_new=True): if allow_new: try: return q.get() - except: + except pw.DoesNotExist: return cls(comp1=pair[0], comp2=pair[1]) else: return q.get() @@ -926,12 +924,10 @@ def other_comp(self, comp): """ if self.comp1 == comp: return self.comp2 - elif self.comp2 == comp: + if self.comp2 == comp: return self.comp1 - else: - raise DoesNotExist( - "The component you passed is not part of this connexion." 
- ) + + raise DoesNotExist("The component you passed is not part of this connexion.") def make( self, time=datetime.datetime.now(), permanent=False, notes=None, force=False @@ -1016,7 +1012,7 @@ class property_component(base_model): prop_type = pw.ForeignKeyField(property_type, backref="property_component") comp_type = pw.ForeignKeyField(component_type, backref="property_component") - class Meta(object): + class Meta: indexes = (("prop_type", "comp_type"), True) @@ -1044,7 +1040,7 @@ class property(event_table): type = pw.ForeignKeyField(property_type, backref="property") value = pw.CharField(max_length=255) - class Meta(object): + class Meta: indexes = (("comp_sn, type_id"), False) @@ -1184,7 +1180,7 @@ class event(base_model): start = pw.ForeignKeyField(timestamp, backref="event_start") end = pw.ForeignKeyField(timestamp, backref="event_end") - class Meta(object): + class Meta: indexes = ((("type_id"), False), (("start", "end"), False)) def _event_permission(self): @@ -1238,8 +1234,8 @@ def deactivate(self): False, LayoutIntegrity, "Cannot deactivate because " - "the following history event%s %s set for this " - "component" % (_plural(fail), _are(fail)), + f"the following history event{_plural(fail)} " + f"{_are(fail)} set for this component", ) # Check documents. @@ -1252,8 +1248,8 @@ def deactivate(self): False, LayoutIntegrity, "Cannot deactivate because " - "the following document event%s %s set for this " - "component" % (_plural(fail), _are(fail)), + f"the following document event{_plural(fail)} " + f"{_are(fail)} set for this component", ) # Check properties. @@ -1266,19 +1262,18 @@ def deactivate(self): False, LayoutIntegrity, "Cannot deactivate because " - "the following property event%s %s set for this " - "component" % (_plural(fail), _are(fail)), + f"the following property event{_plural(fail)} " + f"{_are(fail)} set for this component", ) # Check connexions. for conn in comp.get_connexion(when=EVENT_ALL): - fail.append("%s<->%s" % (conn.comp1.sn, conn.comp2.sn)) + fail.append(conn.comp1.sn + "<=>" + conn.comp2.sn) _check_fail( fail, False, LayoutIntegrity, - "Cannot deactivate because " - "the following component%s are connected" % (_plural(fail)), + "Cannot deactivate because " "the following components are connected", ) self.active = False @@ -1325,14 +1320,14 @@ def _replace(self, start=None, end=None, force_end=False): "This method does not currently support moving a " "component availability event later." 
) - if start == None: + if start is None: start = self.start else: try: timestamp.get(id=start.id) except pw.DoesNotExist: start.save() - if end == None: + if end is None: if not force_end: end = _pw_getattr(self, "end", None) else: @@ -1345,15 +1340,13 @@ def _replace(self, start=None, end=None, force_end=False): self.active = False self.save() - new = event.create( + return event.create( replaces=self, graph_obj=self.graph_obj, type=self.type, start=start, end=end, ) - self = new - return self class predef_subgraph_spec(name_table): @@ -1399,7 +1392,7 @@ class predef_subgraph_spec_param(base_model): type2 = pw.ForeignKeyField(component_type, backref="subgraph_param2", null=True) action = EnumField(["T", "H", "O"]) - class Meta(object): + class Meta: indexes = (("predef_subgraph_spec", "type", "action"), False) @@ -1434,7 +1427,7 @@ class user_permission(base_model): user_id = pw.IntegerField() type = pw.ForeignKeyField(user_permission_type, backref="user") - class Meta(object): + class Meta: indexes = (("user_id", "type"), False) @@ -1473,7 +1466,7 @@ def _graph_obj_iter(sel, obj, time, when, order, active): ) if active: - ret = ret.where(event.active == True) + ret = ret.where(event.active == True) # noqa E712 if (not when == EVENT_AT) and order: if order == ORDER_ASC: @@ -1504,46 +1497,46 @@ def _check_property_type(ptype, ctype): ).where(property_component.comp_type == ctype).get().name except pw.DoesNotExist: raise PropertyType( - 'Property type "%s" cannot be used for component ' - 'type "%s".' % (ptype.name, ctype.name) + f'Property type "{ptype.name}" cannot be used for component ' + f'type "{ctype.name}".' ) def _check_fail(fail, force, exception, msg): if len(fail): - msg = "%s: %s" % (msg, ", ".join(fail)) + msg += ": " + ", ".join(fail) if force: logger.debug(msg) else: raise exception(msg) -def _conj(l): - if len(l) == 1: +def _conj(obj): + if len(obj) == 1: return "s" - else: - return "" + + return "" -def _plural(l): - if len(l) == 1: +def _plural(obj): + if len(obj) == 1: return "" - else: - return "s" + return "s" -def _does(l): - if len(l) == 1: + +def _does(obj): + if len(obj) == 1: return "does" - else: - return "do" + + return "do" -def _are(l): - if len(l) == 1: +def _are(obj): + if len(obj) == 1: return "is" - else: - return "are" + + return "are" def compare_connexion(conn1, conn2): @@ -1569,10 +1562,7 @@ def compare_connexion(conn1, conn2): sn21 = conn2.comp1.sn sn22 = conn2.comp2.sn - if (sn11 == sn21 and sn12 == sn22) or (sn11 == sn22 and sn12 == sn21): - return True - else: - return False + return (sn11 == sn21 and sn12 == sn22) or (sn11 == sn22 and sn12 == sn21) def add_component(comp, time=datetime.datetime.now(), notes=None, force=False): @@ -1589,8 +1579,10 @@ def add_component(comp, time=datetime.datetime.now(), notes=None, force=False): >>> lna_rev = lna_type.rev.where(layout.component_type_rev.name == "B").get() >>> c = [] >>> for i in range(0, 10): - ... c.append(layout.component(sn = "LNA%04dB" % (i), type = lna_type, rev = lna_rev)) - >>> layout.add_component(c, time = datetime(2014, 10, 10, 11), notes = "Adding many at once.") + ... c.append(layout.component(sn = "LNA%04dB" % (i), type = lna_type, + ... rev = lna_rev)) + >>> layout.add_component(c, time = datetime(2014, 10, 10, 11), \ + ... 
notes = "Adding many at once.") Parameters ---------- @@ -1629,7 +1621,7 @@ def add_component(comp, time=datetime.datetime.now(), notes=None, force=False): try: c.event(time, event_type.comp_avail(), EVENT_AT).get() fail.append(c.sn) - except: + except pw.DoesNotExist: to_add.append(comp) to_add_sn.append(comp.sn) @@ -1655,7 +1647,8 @@ def add_component(comp, time=datetime.datetime.now(), notes=None, force=False): force, AlreadyExists, "Aborting because the following " - "component%s %s already available at that time" % (_plural(fail), _are(fail)), + f"component{_plural(fail)} " + f"{_are(fail)} already available at that time", ) if len(to_add): @@ -1728,7 +1721,7 @@ def _check_perm_connexion_recurse(comp, time, done=[]): ev_sn += s done.append(c2) else: - fail.append("%s<->%s" % (conn.comp1.sn, conn.comp2.sn)) + fail.append(conn.comp1.sn + "<=>" + conn.comp2.sn) ev.append(comp.event(time, event_type.comp_avail(), EVENT_AT).get()) ev_sn.append(comp.sn) @@ -1780,7 +1773,7 @@ def remove_component(comp, time=datetime.datetime.now(), notes=None, force=False found_conn = False for conn in c.get_connexion(time=time): if not conn.is_permanent(): - fail_conn.append("%s<->%s" % (conn.comp1.sn, conn.comp2.sn)) + fail_conn.append(conn.comp1.sn + "<=>" + conn.comp2.sn) found_conn = True perm_ev, perm_ev_sn, perm_fail = _check_perm_connexion_recurse(c, time) @@ -1803,30 +1796,27 @@ def remove_component(comp, time=datetime.datetime.now(), notes=None, force=False fail_avail, force, LayoutIntegrity, - "The following component%s " - "%s not available at that time, or you have specified an " - "end time earlier than %s start time%s" - % ( - _plural(fail_avail), - _are(fail_avail), - "its" if len(fail_avail) == 1 else "their", - _plural(fail_avail), - ), + f"The following component{_plural(fail_avail)} " + f"{_are(fail_avail)} not available at that time, or you have specified an " + "end time earlier than " + + ("its" if len(fail_avail) == 1 else "their") + + f"start time{_plural(fail_avail)}", ) _check_fail( fail_conn, force, LayoutIntegrity, - "Cannot remove because the " - "following component%s %s connected" % (_plural(fail_conn), _are(fail_conn)), + "Cannot remove because " + f"the following component{_plural(fail_conn)} " + f"{_are(fail_conn)} connected", ) _check_fail( fail_perm_conn, force, LayoutIntegrity, "Cannot remove because " - "the following component%s %s connected (via permanent " - "connexions)" % (_plural(fail_perm_conn), _are(fail_perm_conn)), + f"the following component{_plural(fail_perm_conn)} " + f"{_are(fail_perm_conn)} connected (via permanent connexions)", ) t_stamp = timestamp.create(time=time, notes=notes) @@ -1887,11 +1877,11 @@ def set_property( comp_list = comp for comp in comp_list: _check_property_type(type, comp.type) - if type.regex and value != None: + if type.regex and value is not None: if not re.match(re.compile(type.regex), value): raise ValueError( - 'Value "%s" does not conform to regular ' - "expression %s." % (value, type.regex) + f'Value "{value}" does not conform to regular ' + "expression {type.regex}." ) fail = [] @@ -1953,8 +1943,8 @@ def set_property( p = property.create(id=o, comp=comp, type=type, value=value) e = event.create(graph_obj=o, type=event_type.property(), start=t_stamp) logger.info( - "Added property %s=%s to the following component%s: %s." 
- % (type.name, value, _plural(to_set), ", ".join(to_set_sn)) + f"Added property {type.name}={value} to the following " + f"component{_plural(to_set)}: " + ", ".join(to_set_sn) ) else: logger.info("No component property was changed.") @@ -1977,7 +1967,8 @@ def make_connexion( ... comp1 = layout.component.get(sn = "LNA%04dB" % (i)) ... comp2 = layout.component.get(sn = "CXA%04dB"% (i)) ... conn.append(layout.connexion.from_pair(comp1, comp2)) - >>> layout.make_connexion(conn, time = datetime(2013, 10, 11, 23, 15), notes = "Making multiple connexions at once.") + >>> layout.make_connexion(conn, time = datetime(2013, 10, 11, 23, 15), \ + ... notes = "Making multiple connexions at once.") Parameters ---------- @@ -2004,10 +1995,10 @@ def make_connexion( to_conn_sn = [] for c in conn: if c.is_connected(time): - fail.append("%s<=>%s" % (c.comp1.sn, c.comp2.sn)) + fail.append(c.comp1.sn + "<=>" + c.comp2.sn) else: to_conn.append(c) - to_conn_sn.append("%s<=>%s" % (c.comp1.sn, c.comp2.sn)) + to_conn_sn.append(c.comp1.sn + "<=>" + c.comp2.sn) if len(fail): _check_fail( fail, @@ -2028,7 +2019,7 @@ def make_connexion( try: conn = connexion.from_pair(c.comp1, c.comp2, allow_new=False) o = conn.id - except: + except pw.DoesNotExist: o = graph_obj.create() conn = connexion.create(id=o, comp1=c.comp1, comp2=c.comp2) if permanent: @@ -2061,7 +2052,8 @@ def sever_connexion(conn, time=datetime.datetime.now(), notes=None, force=False) ... comp1 = layout.component.get(sn = "LNA%04dB" % (i)) ... comp2 = layout.component.get(sn = "CXA%04dB"% (i)) ... conn.append(layout.connexion.from_pair(comp1, comp2)) - >>> layout.sever_connexion(conn, time = datetime(2014, 10, 11, 23, 15), notes = "Severing multiple connexions at once.") + >>> layout.sever_connexion(conn, time = datetime(2014, 10, 11, 23, 15), \ + ... notes = "Severing multiple connexions at once.") Parameters ---------- @@ -2092,22 +2084,22 @@ def sever_connexion(conn, time=datetime.datetime.now(), notes=None, force=False) ev.append( c.event(time=time, type=event_type.connexion(), when=EVENT_AT).get() ) - ev_conn_sn.append("%s<=>%s" % (c.comp1.sn, c.comp2.sn)) + ev_conn_sn.append(c.comp1.sn + "<=>" + c.comp2.sn) except pw.DoesNotExist: try: c.event( time=time, type=event_type.perm_connexion(), when=EVENT_AT ).get() - fail_perm.append("%s<=>%s" % (c.comp1.sn, c.comp2.sn)) + fail_perm.append(c.comp1.sn + "<=>" + c.comp2.sn) except pw.DoesNotExist: - fail_conn.append("%s<=>%s" % (c.comp1.sn, c.comp2.sn)) + fail_conn.append(c.comp1.sn + "<=>" + c.comp2.sn) _check_fail( fail_conn, force, AlreadyExists, "Cannot disconnect because " - "the following connexion%s %s not exist at that time" - % (_plural(fail_conn), _does(fail_conn)), + f"the following connexion{_plural(fail_conn)} " + f"_does{fail_conn} not exist at that time", ) _check_fail( fail_perm, force, LayoutIntegrity, "Cannot disconnect permanent connexions" diff --git a/ch_util/andata.py b/ch_util/andata.py index b3146643..1b117789 100644 --- a/ch_util/andata.py +++ b/ch_util/andata.py @@ -80,11 +80,10 @@ def __new__(cls, h5_data=None, **kwargs): new_cls = subclass_from_obj(cls, h5_data) - self = super(BaseData, new_cls).__new__(new_cls) - return self + return super().__new__(new_cls) def __init__(self, h5_data=None, **kwargs): - super(BaseData, self).__init__(h5_data, **kwargs) + super().__init__(h5_data, **kwargs) if self._data.file.mode == "r+": self._data.require_group("cal") self._data.require_group("flags") @@ -216,17 +215,16 @@ def time(self): # Already a calculated timestamp. 
return self.index_map["time"][:] - else: - time = _timestamp_from_fpga_cpu( - self.index_map["time"]["ctime"], 0, self.index_map["time"]["fpga_count"] - ) + time = _timestamp_from_fpga_cpu( + self.index_map["time"]["ctime"], 0, self.index_map["time"]["fpga_count"] + ) - alignment = self.index_attrs["time"].get("alignment", 0) + alignment = self.index_attrs["time"].get("alignment", 0) - if alignment != 0: - time = time + alignment * abs(np.median(np.diff(time)) / 2) + if alignment != 0: + time = time + alignment * abs(np.median(np.diff(time)) / 2) - return time + return time @classmethod def _interpret_and_read(cls, acq_files, start, stop, datasets, out_group): @@ -509,7 +507,7 @@ def _interpret_and_read( # Remove the FPGA applied gains (need to invert them first). if apply_gain and any( - [re.match(ACQ_VIS_DATASETS, key) for key in data.datasets] + re.match(ACQ_VIS_DATASETS, key) for key in data.datasets ): from ch_util import tools @@ -697,7 +695,7 @@ def from_acq_h5(cls, acq_files, start=None, stop=None, **kwargs): comm=comm, ) - return super(CorrData, cls).from_acq_h5( + return super().from_acq_h5( acq_files=acq_files, start=start, stop=stop, @@ -756,7 +754,7 @@ def _from_acq_h5_distributed( ) # Load just the local part of the data. - local_data = super(CorrData, cls).from_acq_h5( + local_data = super().from_acq_h5( acq_files=acq_files, start=start, stop=stop, @@ -794,7 +792,8 @@ def _from_acq_h5_distributed( # Iterate over the datasets and copy them over for name, old_dset in local_data.datasets.items(): - # If this should be distributed, extract the sections and turn them into an MPIArray + # If this should be distributed, extract the sections and turn them + # into an MPIArray if name in _DIST_DSETS: array = mpiarray.MPIArray.wrap(old_dset._data, axis=0, comm=comm) else: @@ -811,7 +810,8 @@ def _from_acq_h5_distributed( # Iterate over the flags and copy them over for name, old_dset in local_data.flags.items(): - # If this should be distributed, extract the sections and turn them into an MPIArray + # If this should be distributed, extract the sections and turn them + # into an MPIArray if name in _DIST_DSETS: array = mpiarray.MPIArray.wrap(old_dset._data, axis=0, comm=comm) else: @@ -895,7 +895,7 @@ def from_acq_h5_fast(cls, fname, comm=None, freq_sel=None, start=None, stop=None if freq_sel is None: freq_sel = slice(None) if not isinstance(freq_sel, slice): - raise ValueError("freq_sel must be a slice object, not %s" % repr(freq_sel)) + raise ValueError("freq_sel must be a slice object, not " + repr(freq_sel)) # Create the time selection time_sel = slice(start, stop) @@ -1010,7 +1010,7 @@ def chan(self, mux=-1): try: self._chan except AttributeError: - self._chan = dict() + self._chan = {} try: return self._chan[mux] except KeyError: @@ -1148,7 +1148,7 @@ def from_acq_h5( -------- Examples are analogous to those of :meth:`CorrData.from_acq_h5`. 
""" - return super(HKData, cls).from_acq_h5( + return super().from_acq_h5( acq_files=acq_files, start=start, stop=stop, @@ -1464,8 +1464,8 @@ def resample(self, metric_name, rule, how="mean", unstack=False, **kwargs): if unstack: return resampled_df.unstack(group_columns) - else: - return resampled_df.reset_index(group_columns) + + return resampled_df.reset_index(group_columns) class WeatherData(BaseData): @@ -1478,8 +1478,8 @@ def time(self): """ if "time" in self.index_map: return self.index_map["time"] - else: - return self.index_map["station_time_blockhouse"] + + return self.index_map["station_time_blockhouse"] @property def temperature(self): @@ -1489,8 +1489,8 @@ def temperature(self): """ if "blockhouse" in self.keys(): return self["blockhouse"]["outTemp"] - else: - return self["outTemp"] + + return self["outTemp"] def dataset_name_allowed(self, name): """Permits datasets in the root and 'blockhouse' groups.""" @@ -1524,13 +1524,13 @@ def dset_filter(dataset): data = dataset else: raise RuntimeError( - "Dataset (%s) has unexpected shape [%s]." - % (dataset.name, repr(dataset.shape)) + f"Dataset ({dataset.name}) " + f"has unexpected shape [{dataset.shape!r}]." ) return data andata_objs = [RawADCData(d) for d in acq_files] - data = concatenate( + return concatenate( andata_objs, out_group=out_group, start=start, @@ -1540,17 +1540,18 @@ def dset_filter(dataset): convert_attribute_strings=cls.convert_attribute_strings, convert_dataset_strings=cls.convert_dataset_strings, ) - return data class GainFlagData(BaseData): """Subclass of :class:`BaseData` for gain, digitalgain, and flag input acquisitions. - These acquisitions consist of a collection of updates to the real-time pipeline ordered - chronologically. In most cases the updates do not occur at a regular cadence. - The time that each update occured can be accessed via `self.index_map['update_time']`. - In addition, each update is given a unique update ID that can be accessed via - `self.datasets['update_id']` and can be searched using the `self.search_update_id` method. + These acquisitions consist of a collection of updates to the real-time + pipeline ordered chronologically. In most cases the updates do not + occur at a regular cadence. The time that each update occured can be + accessed via `self.index_map['update_time']`. In addition, each update + is given a unique update ID that can be accessed via + `self.datasets['update_id']` and can be searched using the + `self.search_update_id` method. """ def resample(self, dataset, timestamp, transpose=False): @@ -1603,18 +1604,15 @@ def search_update_time(self, timestamp): dmax = np.max(timestamp) - np.max(self.time) if dmax > 0.0: - msg = ( + warnings.warn( "Requested timestamps are after the latest update_time " - "by as much as %0.2f hours." % (dmax / 3600.0,) + f"by as much as {dmax / 3600.0:.2f} hours." ) - warnings.warn(msg) - index = np.digitize(timestamp, self.time, right=False) - 1 - - return index + return np.digitize(timestamp, self.time, right=False) - 1 def search_update_id(self, pattern, is_regex=False): - """Find the index into the `update_time` axis corresponding to a particular `update_id`. + """Find the index into the `update_time` axis for a particular `update_id`. 
Parameters ---------- @@ -1633,14 +1631,13 @@ def search_update_id(self, pattern, is_regex=False): ptn = pattern if is_regex else fnmatch.translate(pattern) regex = re.compile(ptn) - index = np.array( + return np.array( [ii for ii, uid in enumerate(self.update_id[:]) if regex.match(uid)] ) - return index @property def time(self): - """Aliases `index_map['update_time']` to `time` for `caput.tod` functionality.""" + """Returns `index_map['update_time']` for `caput.tod` functionality.""" return self.index_map["update_time"] @property @@ -1835,9 +1832,7 @@ def select_time_range(self, start_time=None, stop_time=None): """ - super(BaseReader, self).select_time_range( - start_time=start_time, stop_time=stop_time - ) + super().select_time_range(start_time=start_time, stop_time=stop_time) def read(self, out_group=None): """Read the selected data. @@ -1872,7 +1867,7 @@ class CorrReader(BaseReader): data_class = CorrData def __init__(self, files): - super(CorrReader, self).__init__(files) + super().__init__(files) data_empty = self._data_empty prod = data_empty.prod freq = data_empty.index_map["freq"] @@ -2102,8 +2097,7 @@ def select_freq_physical(self, frequencies): try: first_match = matches[0][0] except IndexError: - msg = "No match for frequency %f MHz." % frequencies[ii] - raise ValueError(msg) + raise ValueError("No match for frequency {frequencies[ii]} MHz.") freq_inds.append(first_match) self.freq_sel = freq_inds @@ -2224,8 +2218,7 @@ def subclass_from_obj(cls, obj): # If obj is a filename, open it and recurse. if isinstance(obj, str): with h5py.File(obj, "r") as f: - cls = subclass_from_obj(cls, f) - return cls + return subclass_from_obj(cls, f) new_cls = cls acquisition_type = None @@ -2417,8 +2410,7 @@ def _input_sel_from_prod_sel(prod_sel, prod_map): input_sel.append(p0) input_sel.append(p1) # ensure_1D here deals with h5py issue #425. - input_sel = _ensure_1D_selection(sorted(list(set(input_sel)))) - return input_sel + return _ensure_1D_selection(sorted(set(input_sel))) def _prod_sel_from_input_sel(input_sel, input_map, prod_map): @@ -2428,14 +2420,12 @@ def _prod_sel_from_input_sel(input_sel, input_map, prod_map): if p[0] in inputs and p[1] in inputs: prod_sel.append(ii) # ensure_1D here deals with h5py issue #425. 
- prod_sel = _ensure_1D_selection(prod_sel) - return prod_sel + return _ensure_1D_selection(prod_sel) def _stack_sel_from_prod_sel(prod_sel, stack_rmap): stack_sel = stack_rmap["stack"][prod_sel] - stack_sel = _ensure_1D_selection(sorted(list(set(stack_sel)))) - return stack_sel + return _ensure_1D_selection(sorted(set(stack_sel))) def _prod_sel_from_stack_sel(stack_sel, stack_map, stack_rmap): @@ -2448,8 +2438,7 @@ def _prod_sel_from_stack_sel(stack_sel, stack_map, stack_rmap): for ii in range(len(stack_inds)): prod_sel.append(stack_rmap_sort_inds[left_indeces[ii] : right_indeces[ii]]) prod_sel = np.concatenate(prod_sel) - prod_sel = _ensure_1D_selection(sorted(list(set(prod_sel)))) - return prod_sel + return _ensure_1D_selection(sorted(set(prod_sel))) def versiontuple(v): @@ -2551,7 +2540,8 @@ def _unwrap_fpga_counts(data): # Correct the FPGA counts by adding on the counts lost by wrapping fpga_corrected = time_map["fpga_count"] + num_wraps * 2**32 - # Create an array to represent the new time dataset, and fill in the corrected values + # Create an array to represent the new time dataset, + # and fill in the corrected values _time_dtype = [("fpga_count", np.uint64), ("ctime", np.float64)] new_time_map = np.zeros(time_map.shape, dtype=_time_dtype) new_time_map["fpga_count"] = fpga_corrected @@ -2629,11 +2619,10 @@ def _copy_dataset_acq1( if dataset_name == "vis": # Convert to 64 but complex. if set(split_dsets.keys()) != {"imag", "real"}: - msg = ( + raise ValueError( "Visibilities should have fields 'real' and 'imag'" - " and instead have %s." % str(list(split_dsets.keys())) + " and instead have {split_dsets.keys()}." ) - raise ValueError(msg) vis_data = np.empty(split_dsets["real"].shape, dtype=np.complex64) vis_data.real[:] = split_dsets["real"] vis_data.imag[:] = split_dsets["imag"] @@ -2825,45 +2814,45 @@ def _format_split_acq_dataset_acq1(dataset, time_slice): else: out_cal = {} return {"": out}, out_cal - else: - fields = list(dataset[0].dtype.fields.keys()) - # If there is a 'cal' attribute, make sure it's the right shape. + + fields = list(dataset[0].dtype.fields.keys()) + # If there is a 'cal' attribute, make sure it's the right shape. + if "cal" in dataset.attrs: + if dataset.attrs["cal"].shape != (1,): + msg = "'cal' attribute has more than one element." + raise AttributeError(msg) + if len(list(dataset.attrs["cal"].dtype.fields.keys())) != len(fields): + msg = "'cal' attribute not compatible with dataset dtype." + raise AttributeError(msg) + out = {} + out_cal = {} + # Figure out what fields there are and allocate memory. + for field in fields: + dtype = dataset[0][field].dtype + out_arr = np.empty(out_shape, dtype=dtype) + out[field] = out_arr if "cal" in dataset.attrs: - if dataset.attrs["cal"].shape != (1,): - msg = "'cal' attribute has more than one element." - raise AttributeError(msg) - if len(list(dataset.attrs["cal"].dtype.fields.keys())) != len(fields): - msg = "'cal' attribute not compatible with dataset dtype." - raise AttributeError(msg) - out = {} - out_cal = {} - # Figure out what fields there are and allocate memory. + out_cal[field] = memh5.bytes_to_unicode(dataset.attrs["cal"][0][field]) + for jj, ii in enumerate(np.arange(ntime)[time_slice]): + # Copy data for efficient read. + record = dataset[ii] # Copies to memory. 
for field in fields: - dtype = dataset[0][field].dtype - out_arr = np.empty(out_shape, dtype=dtype) - out[field] = out_arr - if "cal" in dataset.attrs: - out_cal[field] = memh5.bytes_to_unicode(dataset.attrs["cal"][0][field]) - for jj, ii in enumerate(np.arange(ntime)[time_slice]): - # Copy data for efficient read. - record = dataset[ii] # Copies to memory. - for field in fields: - if not back_shape: - out[field][jj] = record[field] - elif len(back_shape) == 1: - out[field][:, jj] = record[field][:] - else: - # Multidimensional, try to be more efficient. - it = np.nditer(record[..., 0], flags=["multi_index"], order="C") - while not it.finished: - # Reverse the multiindex for the out array. - ind = it.multi_index + (slice(None),) - ind_rev = list(ind) - ind_rev.reverse() - ind_rev = tuple(ind_rev) + (jj,) - out[field][ind_rev] = record[field][ind] - it.iternext() - return out, out_cal + if not back_shape: + out[field][jj] = record[field] + elif len(back_shape) == 1: + out[field][:, jj] = record[field][:] + else: + # Multidimensional, try to be more efficient. + it = np.nditer(record[..., 0], flags=["multi_index"], order="C") + while not it.finished: + # Reverse the multiindex for the out array. + ind = it.multi_index + (slice(None),) + ind_rev = list(ind) + ind_rev.reverse() + ind_rev = tuple(ind_rev) + (jj,) + out[field][ind_rev] = record[field][ind] + it.iternext() + return out, out_cal def _data_attrs_from_acq_attrs_acq1(acq_attrs): @@ -2990,9 +2979,8 @@ def andata_from_acq1(acq_files, start, stop, prod_sel, freq_sel, datasets, out_g vis_shape = dtypes[dataset_name].shape elif dtypes[dataset_name].shape != vis_shape or len(vis_shape) != 2: msg = ( - "Expected the following datasets to be" - " identically shaped and 3D in Acq files: %s." - % str(ACQ_VIS_SHAPE_DATASETS) + "Expected the following datasets to be identically " + f"shaped and 3D in Acq files: {ACQ_VIS_SHAPE_DATASETS!s}." ) raise ValueError(msg) _copy_dataset_acq1( @@ -3260,9 +3248,7 @@ def _generate_input_map(serials, chans=None): else: chan_iter = list(zip(chans, serials)) - imap = np.array(list(chan_iter), dtype=_imap_dtype) - - return imap + return np.array(list(chan_iter), dtype=_imap_dtype) def _get_versiontuple(afile): @@ -3291,7 +3277,7 @@ def _remap_stone_abbot(afile): serial_map = {1: "0003", 33: "0033", -1: "????"} # Stone # Abbot # Unknown # Construct new array of index_map - serial_pat = "29821-0000-%s-C%%i" % serial_map[serial] + serial_pat = "29821-0000-" + serial_map[serial] + "-C%i" inputmap = _generate_input_map([serial_pat % ci for ci in range(8)]) # Copy out old index_map/input if it exists @@ -3316,7 +3302,8 @@ def _remap_blanchard(afile): # Use time to check if blanchard was in the crate or not if last_time < BPC_END: - # Find list of channels and adc serial using different methods depending on the archive file version + # Find list of channels and adc serial using different methods depending + # on the archive file version if _get_versiontuple(afile) < versiontuple("2.0.0"): # The older files have no index_map/input so we need to guess/construct it. 
chanlist = list(range(16)) @@ -3328,7 +3315,7 @@ def _remap_blanchard(afile): adc_serial = afile.index_map["input"]["adc_serial"][0] # Construct new array of index_map - serial_pat = "29821-0000-%s-C%%02i" % adc_serial + serial_pat = "29821-0000-" + adc_serial + "-C%02i" inputmap = _generate_input_map([serial_pat % ci for ci in chanlist]) else: @@ -3372,7 +3359,8 @@ def _remap_slotX(afile): def _remap_crate_corr(afile, slot): - # Worker routine for remapping the new style files for blanchard, first9ucrate and slotX + # Worker routine for remapping the new style files for blanchard, + # first9ucrate and slotX if _get_versiontuple(afile) < versiontuple("2.0.0"): raise Exception("Only functions with archive 2.0.0 files.") @@ -3508,7 +3496,8 @@ def _insert_gains(data, input_sel): except KeyError: ninput_orig = data.history["acq"]["number_of_antennas"] - # In certain files this entry is a length-1 array, turn it into a scalar if it is not + # In certain files this entry is a length-1 array, + # turn it into a scalar if it is not if isinstance(ninput_orig, np.ndarray): ninput_orig = ninput_orig[0] @@ -3573,7 +3562,8 @@ def _insert_gains(data, input_sel): # Check that the gain datasets have been loaded if ("gain_coeff" not in data.datasets) or ("gain_exp" not in data.datasets): warnings.warn( - "Required gain datasets not loaded from file (> v2.2.0), using unit gains." + "Required gain datasets not loaded from file (> v2.2.0), " + "using unit gains." ) else: diff --git a/ch_util/cal_utils.py b/ch_util/cal_utils.py index fff0f351..a88f9779 100644 --- a/ch_util/cal_utils.py +++ b/ch_util/cal_utils.py @@ -8,7 +8,7 @@ from datetime import datetime import inspect import logging -from typing import Dict, Optional, Union +from typing import Optional, Union import numpy as np import scipy.stats @@ -31,7 +31,7 @@ logger.addHandler(logging.NullHandler()) -class FitTransit(object, metaclass=ABCMeta): +class FitTransit(metaclass=ABCMeta): """Base class for fitting models to point source transits. The `fit` method should be used to populate the `param`, `param_cov`, `chisq`, @@ -60,7 +60,7 @@ class FitTransit(object, metaclass=ABCMeta): """ _tval = {} - component = np.array(["complex"], dtype=np.string_) + component = np.array(["complex"], dtype=np.bytes_) def __init__(self, *args, **kwargs): """Instantiates a FitTransit object. @@ -164,7 +164,7 @@ def fit(self, ha, resp, resp_err, width=5, absolute_sigma=False, **kwargs): dtype = ha.dtype if not np.isscalar(width) and (width.shape != shp): - ValueError("Keyword with must be scalar or have shape %s." % str(shp)) + ValueError(f"Keyword with must be scalar or have shape {shp!s}.") self.param = np.full(shp + (self.nparam,), np.nan, dtype=dtype) self.param_cov = np.full(shp + (self.nparam, self.nparam), np.nan, dtype=dtype) @@ -190,8 +190,8 @@ def fit(self, ha, resp, resp_err, width=5, absolute_sigma=False, **kwargs): absolute_sigma=absolute_sigma, **kwargs, ) - except Exception as error: - logger.debug("Index %s failed with error: %s" % (str(ind), error)) + except (ValueError, KeyError) as error: + logger.debug(f"Index {ind!s} failed with error: {error}") continue self.param[ind] = param @@ -209,7 +209,7 @@ def parameter_names(self): parameter_names : np.ndarray[nparam,] Names of the parameters. 
""" - return np.array(["param%d" % p for p in range(self.nparam)], dtype=np.string_) + return np.array(["param%d" % p for p in range(self.nparam)], dtype=np.bytes_) @property def param_corr(self): @@ -240,6 +240,7 @@ def N(self): """ if self.param is not None: return self.param.shape[:-1] or None + return None @property def nparam(self): @@ -261,7 +262,8 @@ def ncomponent(self): Returns ------- ncomponent : int - Number of components (i.e, real and imag, amp and phase, complex) that have been fit. + Number of components (i.e, real and imag, amp and phase, + complex) that have been fit. """ return self.component.size @@ -407,7 +409,7 @@ def __init__(self, poly_type="standard", *args, **kwargs): poly_type : str Type of polynomial. Can be 'standard', 'hermite', or 'chebyshev'. """ - super(FitPoly, self).__init__(poly_type=poly_type, *args, **kwargs) + super().__init__(poly_type=poly_type, *args, **kwargs) self._set_polynomial_model(poly_type) @@ -430,8 +432,8 @@ def _set_polynomial_model(self, poly_type): self._root = np.polynomial.chebyshev.chebroots else: raise ValueError( - "Do not recognize polynomial type %s." - "Options are 'standard', 'hermite', or 'chebyshev'." % poly_type + f"Do not recognize polynomial type {poly_type}." + "Options are 'standard', 'hermite', or 'chebyshev'." ) self.poly_type = poly_type @@ -460,7 +462,7 @@ class FitRealImag(FitTransit): methods for predicting the uncertainty on each. """ - component = np.array(["real", "imag"], dtype=np.string_) + component = np.array(["real", "imag"], dtype=np.bytes_) def uncertainty_real(self, ha, alpha=0.32, elementwise=False): """Predicts the uncertainty on real component at given hour angle(s). @@ -526,11 +528,10 @@ def uncertainty(self, ha, alpha=0.32, elementwise=False): Uncertainty on the response. """ with np.errstate(all="ignore"): - err = np.sqrt( + return np.sqrt( self.uncertainty_real(ha, alpha=alpha, elementwise=elementwise) ** 2 + self.uncertainty_imag(ha, alpha=alpha, elementwise=elementwise) ** 2 ) - return err def _jacobian(self, ha): raise NotImplementedError( @@ -571,9 +572,7 @@ def __init__(self, poly_deg=5, even=False, odd=False, *args, **kwargs): if even and odd: raise RuntimeError("Cannot request both even AND odd.") - super(FitPolyRealPolyImag, self).__init__( - poly_deg=poly_deg, even=even, odd=odd, *args, **kwargs - ) + super().__init__(poly_deg=poly_deg, even=even, odd=odd, *args, **kwargs) self.poly_deg = poly_deg self.even = even @@ -756,7 +755,7 @@ def parameter_names(self): return np.array( ["%s_poly_real_coeff%d" % (self.poly_type, p) for p in range(self.nparr)] + ["%s_poly_imag_coeff%d" % (self.poly_type, p) for p in range(self.npari)], - dtype=np.string_, + dtype=np.bytes_, ) def peak(self): @@ -772,7 +771,7 @@ class FitAmpPhase(FitTransit): methods for predicting the uncertainty on each. """ - component = np.array(["amplitude", "phase"], dtype=np.string_) + component = np.array(["amplitude", "phase"], dtype=np.bytes_) def uncertainty_amp(self, ha, alpha=0.32, elementwise=False): """Predicts the uncertainty on amplitude at given hour angle(s). @@ -838,11 +837,10 @@ def uncertainty(self, ha, alpha=0.32, elementwise=False): Uncertainty on the response. 
""" with np.errstate(all="ignore"): - err = np.abs(self._model(ha, elementwise=elementwise)) * np.sqrt( + return np.abs(self._model(ha, elementwise=elementwise)) * np.sqrt( self.uncertainty_amp(ha, alpha=alpha, elementwise=elementwise) ** 2 + self.uncertainty_phi(ha, alpha=alpha, elementwise=elementwise) ** 2 ) - return err def _jacobian(self, ha): raise NotImplementedError( @@ -878,7 +876,7 @@ def __init__(self, poly_deg_amp=5, poly_deg_phi=5, *args, **kwargs): poly_deg_phi : int Degree of the polynomial to fit to phase. """ - super(FitPolyLogAmpPolyPhase, self).__init__( + super().__init__( poly_deg_amp=poly_deg_amp, poly_deg_phi=poly_deg_phi, *args, **kwargs ) @@ -974,7 +972,9 @@ def _fit( if window is not None: if kk > 0: - center = self.peak(param=coeff) + raise RuntimeError("coeff is not defined") + # Where is `coeff` supposed to be defined? + # center = self.peak(param=coeff) if np.isnan(center): raise RuntimeError("No peak found.") @@ -1147,14 +1147,14 @@ def parameter_names(self): return np.array( ["%s_poly_amp_coeff%d" % (self.poly_type, p) for p in range(self.npara)] + ["%s_poly_phi_coeff%d" % (self.poly_type, p) for p in range(self.nparp)], - dtype=np.string_, + dtype=np.bytes_, ) class FitGaussAmpPolyPhase(FitPoly, FitAmpPhase): """Class that enables fits of a gaussian to amplitude and a polynomial to phase.""" - component = np.array(["complex"], dtype=np.string_) + component = np.array(["complex"], dtype=np.bytes_) npara = 3 def __init__(self, poly_deg_phi=5, *args, **kwargs): @@ -1165,9 +1165,7 @@ def __init__(self, poly_deg_phi=5, *args, **kwargs): poly_deg_phi : int Degree of the polynomial to fit to phase. """ - super(FitGaussAmpPolyPhase, self).__init__( - poly_deg_phi=poly_deg_phi, *args, **kwargs - ) + super().__init__(poly_deg_phi=poly_deg_phi, *args, **kwargs) self.poly_deg_phi = poly_deg_phi self.nparp = poly_deg_phi + 1 @@ -1287,16 +1285,17 @@ def fit_func(x, *param): model_amp = peak_amplitude * np.exp(-4.0 * np.log(2.0) * (dxr / fwhm) ** 2) model_phase = self._eval(xr, poly_coeff) - model = np.concatenate( + return np.concatenate( (model_amp * np.cos(model_phase), model_amp * np.sin(model_phase)) ) - return model - return fit_func def _get_fit_jac(self): - """Generates a function that can be used by `curve_fit` to compute jacobian of the model.""" + """Get `curve_fit` Jacobian + + Generates a function that can be used by `curve_fit` to compute + Jacobian of the model.""" def fit_jac(x, *param): """Function used by `curve_fit` to compute the jacobian. 
@@ -1424,7 +1423,7 @@ def parameter_names(self): return np.array( ["peak_amplitude", "centroid", "fwhm"] + ["%s_poly_phi_coeff%d" % (self.poly_type, p) for p in range(self.nparp)], - dtype=np.string_, + dtype=np.bytes_, ) @property @@ -1643,32 +1642,32 @@ def fit_point_source_map( "fringe_rate": 22.0 * freq * 1e6 / 3e8, } - lb_dict = { - "peak_amplitude": 0.0, - "centroid_x": ra0 - 1.5, - "centroid_y": dec0 - 0.75, - "fwhm_x": 0.5, - "fwhm_y": 0.5, - "offset": offset0 - 2.0 * np.abs(offset0), - "fringe_rate": -200.0, - } - - ub_dict = { - "peak_amplitude": 1.5 * peak0, - "centroid_x": ra0 + 1.5, - "centroid_y": dec0 + 0.75, - "fwhm_x": 6.0, - "fwhm_y": 6.0, - "offset": offset0 + 2.0 * np.abs(offset0), - "fringe_rate": 200.0, - } + # lb_dict = { + # "peak_amplitude": 0.0, + # "centroid_x": ra0 - 1.5, + # "centroid_y": dec0 - 0.75, + # "fwhm_x": 0.5, + # "fwhm_y": 0.5, + # "offset": offset0 - 2.0 * np.abs(offset0), + # "fringe_rate": -200.0, + # } + + # ub_dict = { + # "peak_amplitude": 1.5 * peak0, + # "centroid_x": ra0 + 1.5, + # "centroid_y": dec0 + 0.75, + # "fwhm_x": 6.0, + # "fwhm_y": 6.0, + # "offset": offset0 + 2.0 * np.abs(offset0), + # "fringe_rate": 200.0, + # } p0 = np.array([p0_dict[key] for key in param_name]) - bounds = ( - np.array([lb_dict[key] for key in param_name]), - np.array([ub_dict[key] for key in param_name]), - ) + # bounds = ( + # np.array([lb_dict[key] for key in param_name]), + # np.array([ub_dict[key] for key in param_name]), + # ) # Define model if do_dirty: @@ -1696,11 +1695,8 @@ def fit_point_source_map( sigma=this_rms, absolute_sigma=True, ) # , bounds=bounds) - except Exception as error: - print( - "index %s: %s" - % ("(" + ", ".join(["%d" % ii for ii in index]) + ")", error) - ) + except ValueError as error: + print("index (" + ", ".join(["%d" % ii for ii in index]) + "):", error) continue # Save the results @@ -2081,11 +2077,13 @@ def fit_histogram( Parameters ---------- arr : np.ndarray - 1D array containing the data. Arrays with more than one dimension are flattened. + 1D array containing the data. Arrays with more than one dimension + are flattened. bins : int or sequence of scalars or str - If `bins` is an int, it defines the number of equal-width bins in `rng`. - - If `bins` is a sequence, it defines a monotonically increasing array of bin edges, - including the rightmost edge, allowing for non-uniform bin widths. + - If `bins` is a sequence, it defines a monotonically increasing + array of bin edges, including the rightmost edge, allowing for + non-uniform bin widths. - If `bins` is a string, it defines a method for computing the bins. rng : (float, float) The lower and upper range of the bins. If not provided, then the range spans @@ -2096,7 +2094,8 @@ def fit_histogram( test_normal : bool Apply the Shapiro-Wilk and Anderson-Darling tests for normality to the data. return_histogram : bool - Return the histogram. Otherwise return only the best fit parameters and test statistics. + Return the histogram. Otherwise return only the best fit parameters + and test statistics. Returns ------- @@ -2119,7 +2118,8 @@ def fit_histogram( pte : float The probability to observe the chi-squared of the fit. - If `return_histogram` is True, then `results` will also contain the following fields: + If `return_histogram` is True, then `results` will also contain the + following fields: bin_centre : np.ndarray The bin centre of the histogram. @@ -2132,7 +2132,7 @@ def fit_histogram( stat : float The Shapiro-Wilk test statistic. 
pte : float - The probability to observe `stat` if the data were drawn from a gaussian. + The probability to observe `stat` if the data were drawn from a gaussian anderson : dict stat : float The Anderson-Darling test statistic. @@ -2264,8 +2264,8 @@ def flag_outliers(raw, flag, window=25, nsigma=5.0): window : int Window size (in number of samples) used to determine local median. nsigma : float - Data is considered an outlier if it is greater than this number of median absolute - deviations away from the local median. + Data is considered an outlier if it is greater than this number of + median absolute deviations away from the local median. Returns ------- not_outlier : np.ndarray[nsample,] @@ -2315,9 +2315,7 @@ def flag_outliers(raw, flag, window=25, nsigma=5.0): sig = 1.4826 * np.nanmedian(_sliding_window(expanded_resid, rwidth), axis=-1) - not_outlier = resid < (nsigma * sig) - - return not_outlier + return resid < (nsigma * sig) def interpolate_gain(freq, gain, weight, flag=None, length_scale=30.0): @@ -2343,11 +2341,13 @@ def interpolate_gain(freq, gain, weight, flag=None, length_scale=30.0): Returns ------- interp_gain : np.ndarray[nfreq, ninput] - For frequencies with `flag = True`, this will be equal to gain. For frequencies with - `flag = False`, this will be an interpolation of the gains with `flag = True`. + For frequencies with `flag = True`, this will be equal to gain. + For frequencies with `flag = False`, this will be an interpolation + of the gains with `flag = True`. interp_weight : np.ndarray[nfreq, ninput] - For frequencies with `flag = True`, this will be equal to weight. For frequencies with - `flag = False`, this will be the expected uncertainty on the interpolation. + For frequencies with `flag = True`, this will be equal to weight. + For frequencies with `flag = False`, this will be the expected + uncertainty on the interpolation. """ from sklearn import gaussian_process from sklearn.gaussian_process.kernels import Matern, ConstantKernel @@ -2448,9 +2448,7 @@ def interpolate_gain_quiet(*args, **kwargs): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=ConvergenceWarning, module="sklearn") - results = interpolate_gain(*args, **kwargs) - - return results + return interpolate_gain(*args, **kwargs) def thermal_amplitude(delta_T, freq): @@ -2492,7 +2490,7 @@ def get_reference_times_file( times: np.ndarray, cal_file: memh5.MemGroup, logger: Optional[logging.Logger] = None, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """For a given set of times determine when and how they were calibrated. This uses the pre-calculated calibration time reference files. @@ -2513,9 +2511,9 @@ def get_reference_times_file( reftime_result : dict A dictionary containing four entries: - - reftime: Unix time of same length as `times`. Reference times of transit of the - source used to calibrate the data at each time in `times`. Returns `NaN` for - times without a reference. + - reftime: Unix time of same length as `times`. Reference times of + transit of the source used to calibrate the data at each time in + `times`. Returns `NaN` for times without a reference. - reftime_prev: The Unix time of the previous gain update. Only set for time samples that need to be interpolated, otherwise `NaN`. - interp_start: The Unix time of the start of the interpolation period. 
Only @@ -2647,21 +2645,19 @@ def get_reference_times_file( logger.warning(msg.format(n_bad_times, ntimes)) # Bundle result in dictionary - result = { + return { "reftime": reftime, "reftime_prev": reftime_prev, "interp_start": interp_start, "interp_stop": interp_stop, } - return result - def get_reference_times_dataset_id( times: np.ndarray, dataset_ids: np.ndarray, logger: Optional[logging.Logger] = None, -) -> Dict[str, Union[np.ndarray, Dict]]: +) -> dict[str, Union[np.ndarray, dict]]: """Calculate the relevant calibration reference times from the dataset IDs. .. warning:: @@ -2722,7 +2718,7 @@ def get_reference_times_dataset_id( # After restart we sometimes have only a timing update without a source # reference. These aren't valid for our purposes here, and can be distinguished # at the update_id doesn't contain source information, and is thus shorter - d["valid"] = any([src in split_id for src in _source_dict.keys()]) + d["valid"] = any(src in split_id for src in _source_dict.keys()) d["interpolated"] = "transition" in split_id # If it's not a valid update we shouldn't try to extract everything else if not d["valid"]: diff --git a/ch_util/chan_monitor.py b/ch_util/chan_monitor.py index 37e0e844..c69574c9 100644 --- a/ch_util/chan_monitor.py +++ b/ch_util/chan_monitor.py @@ -26,7 +26,7 @@ # _DEFAULT_NODE_SPOOF = {'gong': '/mnt/gong/archive'} # For tests on Marimba -class FeedLocator(object): +class FeedLocator: """This class contains functions that do all the computations to determine feed positions from data. It also determines the quality of data and returns a list of good inputs and frequencies. @@ -233,14 +233,13 @@ def yparams(xx, yy): for ii in range(Nmax): yslp, a = yparams(fr, yslp) a_incr = abs(a - a_prev) / (abs(a + a_prev) * 0.5) - pass_y = a_incr < 1e-2 if a_incr.all(): break - else: - a_prev = a + + a_prev = a # TODO: now it's only one per chan. Change rotation coda appropriatelly - c_ydists = ( + return ( a / 1e6 * C @@ -254,8 +253,6 @@ def yparams(xx, yy): ) ) - return c_ydists - def get_c_ydist_perfreq(self, ph1=None, ph2=None): """Old N-S dists function. TO be used only in case a continuum of frequencies is not available @@ -304,15 +301,13 @@ def get_c_ydist_perfreq(self, ph1=None, ph2=None): [c_ydists + base_up, c_ydists + base_ctr, c_ydists + base_down] ) idxs = np.argmin(abs(dist_opts - c_ydists0[np.newaxis, np.newaxis, :]), axis=0) - c_ydists = np.array( + return np.array( [ [dist_opts[idxs[ii, jj], ii, jj] for jj in range(self.Npr)] for ii in range(self.Nfr) ] ) - return c_ydists - # TODO: change to 'yparams' def params_ft(self, tm, vis, dec, x0_shift=5.0): """Extract relevant parameters from source transit @@ -344,7 +339,6 @@ def params_ft(self, tm, vis, dec, x0_shift=5.0): from scipy.optimize import curve_fit freqs = self.freqs - prods = self.prods # Gaussian function for fit: def gaus(x, A, mu, sig2): @@ -402,8 +396,7 @@ def gaus(x, A, mu, sig2): try: popt, pcov = curve_fit(gaus, fr_ord, ft_ord[ii, jj, :], p0[ii, jj]) prms[ii, jj] = np.array(popt) - # TODO: look for the right exception: - except: + except (KeyError, ValueError): # TODO: Use masked arrays instead of None? 
prms[ii, jj] = [None] * 3 @@ -425,7 +418,7 @@ def getparams_ft(self): # to make it more consistent def get_xdist(self, ft_prms, dec): """E-W""" - xdists = ( + return ( -ft_prms[..., 1] * SD * C @@ -438,8 +431,6 @@ def get_xdist(self, ft_prms, dec): ) ) - return xdists - def data_quality(self): """ """ if self.pass_xd1 is None: @@ -526,7 +517,7 @@ def get_dists(self): def set_good_ipts(self, base_ipts): """Good_prods to good_ipts""" - inp_list = [inpt for inpt in self.inputs] # Full input list + inp_list = [self.inputs] # Full input list self.good_ipts = np.zeros(self.inputs.shape, dtype=bool) for ii, inprd in enumerate(self.inprds): if inprd[0] not in base_ipts: @@ -560,11 +551,9 @@ def solv_pos(self, dists, base_ipt): # Positions: pstns = np.dot(psd_inv, dists) # Add position of base_input - inp_list = [inpt for inpt in self.inputs] # Full input list + inp_list = [self.inputs] # Full input list bs_inpt_idx = inp_list.index(base_ipt) # Original index of base_ipt - pstns = np.insert(pstns, bs_inpt_idx, 0.0) - - return pstns + return np.insert(pstns, bs_inpt_idx, 0.0) def get_postns(self): """ """ @@ -585,13 +574,14 @@ def xdist_test(self, xds1, xds2=None, tol=2.0): def get_centre(xdists, tol): """Returns the median (across frequencies) of NS separation dists for each baseline if this median is withing *tol* of a multiple of 22 meters. Else, - returns the multiple of 22 meters closest to this median (up to 3*22=66 meters) + returns the multiple of 22 meters closest to this median (up to 3*22=66 + meters) """ xmeds = np.nanmedian(xdists, axis=0) cylseps = np.arange(-1, 2) * 22.0 if self.PATH else np.arange(-3, 4) * 22.0 devs = abs(xmeds[:, np.newaxis] - cylseps[np.newaxis, :]) devmins = devs.min(axis=1) - centres = np.array( + return np.array( [ ( xmeds[ii] # return median @@ -602,8 +592,6 @@ def get_centre(xdists, tol): ] ) - return centres - xcentre1 = get_centre(xds1, tol) xerr1 = abs(xds1 - xcentre1[np.newaxis, :]) self.pass_xd1 = xerr1 < tol @@ -661,7 +649,7 @@ def good_prod_freq( return good_chans, good_freqs -class ChanMonitor(object): +class ChanMonitor: """This class provides the user interface to FeedLocator. It initializes instances of FeedLocator (normally one per polarization) @@ -669,9 +657,9 @@ class ChanMonitor(object): agreement/disagreement with the layout database, etc.) Feed locator should not - have to sepparate the visibilities in data to run the test on and data not to run the - test on. ChanMonitor should make the sepparation and provide FeedLocator with the right - data cube to test. + have to sepparate the visibilities in data to run the test on and data + not to run the test on. ChanMonitor should make the sepparation and + provide FeedLocator with the right data cube to test. 
Parameters ---------- @@ -789,7 +777,7 @@ def get_src_cndts(self): grds = [grd if clr[ii] else 0 for ii, grd in enumerate(grds)] # Source candidates ordered in decreasing quality - src_cndts = [ + return [ src for grd, src in sorted( zip(grds, srcs), key=lambda entry: entry[0], reverse=True @@ -797,8 +785,6 @@ def get_src_cndts(self): if grd != 0 ] - return src_cndts - def get_pol_prod_idx(self, pol_inpt_idx): """ """ pol_prod_idx = [] @@ -1028,7 +1014,7 @@ def get_data(self): # TODO: correct process_synced_data to not crash when no NS try: self.dat1 = ni_utils.process_synced_data(self.dat1) - except: + except (KeyError, ValueError): pass self.freqs = self.dat1.freq self.prods = self.dat1.prod @@ -1051,7 +1037,7 @@ def get_data(self): # TODO: correct process_synced_data to not crash when no NS try: self.dat2 = ni_utils.process_synced_data(self.dat2) - except: + except (KeyError, ValueError): pass self.tm2 = self.dat2.time break @@ -1072,7 +1058,7 @@ def get_results(self, src, tdelt=2800): f = copy.deepcopy(self.finder) else: f = finder.Finder(node_spoof=_DEFAULT_NODE_SPOOF) - f.filter_acqs((data_index.ArchiveInst.name == "pathfinder")) + f.filter_acqs(data_index.ArchiveInst.name == "pathfinder") f.only_corr() f.set_time_range(self.t1, self.t2) @@ -1093,7 +1079,7 @@ def set_acq_list(self): # Create a Finder object and focus on time range f = finder.Finder(node_spoof=_DEFAULT_NODE_SPOOF) - f.filter_acqs((data_index.ArchiveInst.name == "pathfinder")) + f.filter_acqs(data_index.ArchiveInst.name == "pathfinder") f.only_corr() f.set_time_range(self.t1, self.t2) @@ -1221,8 +1207,8 @@ def full_check(self): if self.source2 is None: if self.source1 is None: raise RuntimeError("No sources available.") - else: - self.single_source_check() + + self.single_source_check() else: Nipts = len(self.inputs) self.good_ipts = np.zeros(Nipts, dtype=bool) @@ -1296,8 +1282,7 @@ def get_test_res(self, fl): self.expostns[ii][0] = fl.expx[jj] self.expostns[ii][1] = fl.expy[jj] - good_frac = float(np.sum(fl.good_ipts)) / float(fl.good_ipts.size) - return good_frac + return float(np.sum(fl.good_ipts)) / float(fl.good_ipts.size) def get_res_sing_src(self, fl): """ """ @@ -1306,5 +1291,4 @@ def get_res_sing_src(self, fl): if fl_ipt == ipt: self.good_ipts[ii] = fl.good_ipts[jj] - good_frac = float(np.sum(fl.good_ipts)) / float(fl.good_ipts.size) - return good_frac + return float(np.sum(fl.good_ipts)) / float(fl.good_ipts.size) diff --git a/ch_util/connectdb.py b/ch_util/connectdb.py index 1b6f7600..638886d8 100644 --- a/ch_util/connectdb.py +++ b/ch_util/connectdb.py @@ -2,22 +2,22 @@ import warnings -warnings.warn( - "The ch_util.connectdb module is deprecated. 
Use the chimedb package instead" +from chimedb.core.connectdb import ( + NoRouteToDatabase as NoRouteToDatabase, + ConnectionError as ConnectionError, + ALL_RANKS as ALL_RANKS, + current_connector as current_connector, + connect_this_rank as connect_this_rank, + MySQLDatabaseReconnect as MySQLDatabaseReconnect, + BaseConnector as BaseConnector, + MySQLConnector as MySQLConnector, + SqliteConnector as SqliteConnector, + tunnel_active as tunnel_active, + connected_mysql as connected_mysql, + close_mysql as close_mysql, + connect as connect, ) -from chimedb.core.connectdb import ( - NoRouteToDatabase, - ConnectionError, - ALL_RANKS, - current_connector, - connect_this_rank, - MySQLDatabaseReconnect, - BaseConnector, - MySQLConnector, - SqliteConnector, - tunnel_active, - connected_mysql, - close_mysql, - connect, +warnings.warn( + "The ch_util.connectdb module is deprecated. Use the chimedb package instead" ) diff --git a/ch_util/data_index.py b/ch_util/data_index.py index bc998b39..6507a3a8 100644 --- a/ch_util/data_index.py +++ b/ch_util/data_index.py @@ -9,126 +9,131 @@ import warnings -warnings.warn("The ch_util.data_index module is deprecated.") - # Restore all the public symbols -from . import layout, ephemeris - -from chimedb.core.orm import JSONDictField, EnumField, base_model, name_table +from chimedb.core.orm import ( + JSONDictField as JSONDictField, + EnumField as EnumField, + base_model as base_model, + name_table as name_table, +) -from chimedb.core import ValidationError as Validation -from chimedb.core import AlreadyExistsError as AlreadyExists -from chimedb.core import InconsistencyError as DataBaseError +from chimedb.core import ValidationError as Validation # noqa F401 +from chimedb.core import AlreadyExistsError as AlreadyExists # noqa F401 +from chimedb.core import InconsistencyError as DataBaseError # noqa F401 from chimedb.data_index import ( - AcqType, - ArchiveAcq, - ArchiveFile, - ArchiveFileCopy, - ArchiveFileCopyRequest, - ArchiveInst, - CalibrationGainFileInfo, - CorrAcqInfo, - CorrFileInfo, - DigitalGainFileInfo, - FileType, - FlagInputFileInfo, - HKAcqInfo, - HKFileInfo, - HKPFileInfo, - RawadcAcqInfo, - RawadcFileInfo, - StorageGroup, - StorageNode, - WeatherFileInfo, + AcqType as AcqType, + ArchiveAcq as ArchiveAcq, + ArchiveFile as ArchiveFile, + ArchiveFileCopy as ArchiveFileCopy, + ArchiveFileCopyRequest as ArchiveFileCopyRequest, + ArchiveInst as ArchiveInst, + CalibrationGainFileInfo as CalibrationGainFileInfo, + CorrAcqInfo as CorrAcqInfo, + CorrFileInfo as CorrFileInfo, + DigitalGainFileInfo as DigitalGainFileInfo, + FileType as FileType, + FlagInputFileInfo as FlagInputFileInfo, + HKAcqInfo as HKAcqInfo, + HKFileInfo as HKFileInfo, + HKPFileInfo as HKPFileInfo, + RawadcAcqInfo as RawadcAcqInfo, + RawadcFileInfo as RawadcFileInfo, + StorageGroup as StorageGroup, + StorageNode as StorageNode, + WeatherFileInfo as WeatherFileInfo, ) -from chimedb.dataflag import DataFlagType, DataFlag +from chimedb.dataflag import ( + DataFlagType as DataFlagType, + DataFlag as DataFlag, +) from chimedb.data_index.util import ( - fname_atmel, - md5sum_file, - parse_acq_name, - parse_corrfile_name, - parse_weatherfile_name, - parse_hkfile_name, - detect_file_type, + fname_atmel as fname_atmel, + md5sum_file as md5sum_file, + parse_acq_name as parse_acq_name, + parse_corrfile_name as parse_corrfile_name, + parse_weatherfile_name as parse_weatherfile_name, + parse_hkfile_name as parse_hkfile_name, + detect_file_type as detect_file_type, ) -from ._db_tables import 
connect_peewee_tables as connect_database +from ._db_tables import connect_peewee_tables as connect_database # noqa F401 from .holography import ( - QUALITY_GOOD, - QUALITY_OFFSOURCE, - ONSOURCE_DIST_TO_FLAG, - HolographySource, - HolographyObservation, + QUALITY_GOOD as QUALITY_GOOD, + QUALITY_OFFSOURCE as QUALITY_OFFSOURCE, + ONSOURCE_DIST_TO_FLAG as ONSOURCE_DIST_TO_FLAG, + HolographySource as HolographySource, + HolographyObservation as HolographyObservation, ) -_property = property # Do this since there is a class "property" in _db_tables. from ._db_tables import ( - EVENT_AT, - EVENT_BEFORE, - EVENT_AFTER, - EVENT_ALL, - ORDER_ASC, - ORDER_DESC, - NoSubgraph, - BadSubgraph, - DoesNotExist, - UnknownUser, - NoPermission, - LayoutIntegrity, - PropertyType, - PropertyUnchanged, - ClosestDraw, - event_table, - set_user, - graph_obj, - global_flag_category, - global_flag, - component_type, - component_type_rev, - external_repo, - component, - component_history, - component_doc, - connexion, - property_type, - property_component, - property, - event_type, - timestamp, - event, - predef_subgraph_spec, - predef_subgraph_spec_param, - user_permission_type, - user_permission, - compare_connexion, - add_component, - remove_component, - set_property, - make_connexion, - sever_connexion, - connect_peewee_tables, + EVENT_AT as EVENT_AT, + EVENT_BEFORE as EVENT_BEFORE, + EVENT_AFTER as EVENT_AFTER, + EVENT_ALL as EVENT_ALL, + ORDER_ASC as ORDER_ASC, + ORDER_DESC as ORDER_DESC, + NoSubgraph as NoSubgraph, + BadSubgraph as BadSubgraph, + DoesNotExist as DoesNotExist, + UnknownUser as UnknownUser, + NoPermission as NoPermission, + LayoutIntegrity as LayoutIntegrity, + PropertyType as PropertyType, + PropertyUnchanged as PropertyUnchanged, + ClosestDraw as ClosestDraw, + event_table as event_table, + set_user as set_user, + graph_obj as graph_obj, + global_flag_category as global_flag_category, + global_flag as global_flag, + component_type as component_type, + component_type_rev as component_type_rev, + external_repo as external_repo, + component as component, + component_history as component_history, + component_doc as component_doc, + connexion as connexion, + property_type as property_type, + property_component as property_component, + property as property, + event_type as event_type, + timestamp as timestamp, + event as event, + predef_subgraph_spec as predef_subgraph_spec, + predef_subgraph_spec_param as predef_subgraph_spec_param, + user_permission_type as user_permission_type, + user_permission as user_permission, + compare_connexion as compare_connexion, + add_component as add_component, + remove_component as remove_component, + set_property as set_property, + make_connexion as make_connexion, + sever_connexion as sever_connexion, + connect_peewee_tables as connect_peewee_tables, ) from .finder import ( - GF_REJECT, - GF_RAISE, - GF_WARN, - GF_ACCEPT, - Finder, - DataIntervalList, - BaseDataInterval, - CorrDataInterval, - DataInterval, - HKDataInterval, - WeatherDataInterval, - FlagInputDataInterval, - CalibrationGainDataInterval, - DigitalGainDataInterval, - files_in_range, - DataFlagged, + GF_REJECT as GF_REJECT, + GF_RAISE as GF_RAISE, + GF_WARN as GF_WARN, + GF_ACCEPT as GF_ACCEPT, + Finder as Finder, + DataIntervalList as DataIntervalList, + BaseDataInterval as BaseDataInterval, + CorrDataInterval as CorrDataInterval, + DataInterval as DataInterval, + HKDataInterval as HKDataInterval, + WeatherDataInterval as WeatherDataInterval, + FlagInputDataInterval as FlagInputDataInterval, + 
CalibrationGainDataInterval as CalibrationGainDataInterval, + DigitalGainDataInterval as DigitalGainDataInterval, + files_in_range as files_in_range, + DataFlagged as DataFlagged, ) + +warnings.warn("The ch_util.data_index module is deprecated.") diff --git a/ch_util/data_quality.py b/ch_util/data_quality.py index 6260c68b..6d33f8eb 100644 --- a/ch_util/data_quality.py +++ b/ch_util/data_quality.py @@ -108,7 +108,8 @@ def good_channels( And to create a plot of the results: - >>> good_gains, good_noise, good_fit, test_chans = good_channels(data,test_freq=3,res_plot=True) + >>> good_gains, good_noise, good_fit, test_chans = \ + ... good_channels(data,test_freq=3,res_plot=True) """ @@ -131,9 +132,9 @@ def good_channels( n_samp = t_step * bwdth # Processing noise synced data, if noise_synced != False: - if noise_synced == False: + if noise_synced is False: pass - elif noise_synced == True: + elif noise_synced is True: if is_gated_format: # If data is gated, ignore noise_synced argument: msg = ( @@ -145,12 +146,12 @@ def good_channels( else: # Process noise synced data: data = ni_utils.process_synced_data(data) - elif noise_synced == None: + elif noise_synced is None: # If noise_synced is not given, try to read ni_enable from data: try: # Newer data have a noise-injection flag ni_enable = data.attrs["fpga.ni_enable"][0].astype(bool) - except: + except KeyError: # If no info is found, run function to determine ni_enable: ni_enable = _check_ni(data, test_freq) # If noise injection is enabled and data is not gated: @@ -164,7 +165,7 @@ def good_channels( autos_index, autos_chan = _get_autos_index(prod_array_full) # Select auto-corrs and test_freq only: visi = np.array([data.vis[test_freq, jj, :] for jj in autos_index]) - chan_array = np.array([chan for chan in autos_chan]) + chan_array = np.array(autos_chan) tmstp = data.index_map["time"]["ctime"] # Remove non-chime channels (Noise source, RFI, 26m...): @@ -434,7 +435,7 @@ def _noise_test(visi, tmstp, n_samp, tol): rnt[ii] = rnt_med # List of good noise channels (Initialized with all True): - good_noise = np.ones((Nchans)) + good_noise = np.ones(Nchans) # Test noise against tolerance and isnan, isinf: for ii in range(Nchans): is_nan_inf = np.isnan(rnt[ii]) or np.isinf(rnt[ii]) @@ -469,7 +470,7 @@ def _radiom_noise(trace, n_samp, wind=100): # Use MAD to estimate RMS. More robust against RFI/correlator spikes. # sqrt(2) factor is due to my subtracting even - odd time bins. 
# 1.4826 factor is to go from MAD to RMS of a normal distribution: - # rms = [ np.std(entry)/np.sqrt(2) for entry in t_s ] # Using MAD to estimate rms for now + # rms = [ np.std(entry)/np.sqrt(2) for entry in t_s ] # Using MAD to estimate rms rms = [ np.median([np.abs(entry[ii] - np.median(entry)) for ii in range(len(entry))]) * 1.4826 @@ -633,30 +634,28 @@ def _stats_print(good_noise, good_gains, good_fit, test_chans): Nact = len(test_chans) # Number of active channels if good_noise is not None: Nnoisy = Nact - int(np.sum(good_noise)) + percent = Nnoisy * 100.0 / Nact print( - "Noisy channels: {0} out of {1} active channels ({2:2.1f}%)".format( - Nnoisy, Nact, Nnoisy * 100 / Nact - ) + f"Noisy channels: {Nnoisy} out of {Nact} active channels ({percent:2.1f}%)" ) good_chans = good_chans * good_noise else: Nnoisy = None if good_gains is not None: Ngains = Nact - int(np.sum(good_gains)) + percent = Ngains * 100.0 / Nact print( - "High digital gains: {0} out of {1} active channels ({2:2.1f}%)".format( - Ngains, Nact, Ngains * 100 / Nact - ) + "High digital gains: " + f"{Ngains} out of {Nact} active channels ({percent:2.1f}%)" ) good_chans = good_chans * good_gains else: Ngains = None if good_fit is not None: Nfit = Nact - int(np.sum(good_fit)) + percent = Nfit * 100.0 / Nact print( - "Bad fit to T_sky: {0} out of {1} active channels ({2:2.1f}%)".format( - Nfit, Nact, Nfit * 100 / Nact - ) + f"Bad fit to T_sky: {Nfit} out of {Nact} active channels ({percent:2.1f}%)" ) good_chans = good_chans * good_fit else: @@ -666,11 +665,8 @@ def _stats_print(good_noise, good_gains, good_fit, test_chans): if not ((good_noise is None) and (good_gains is None) and (good_fit is None)): Nbad = Nact - int(np.sum(good_chans)) - print( - "Overall bad: {0} out of {1} active channels ({2:2.1f}%)\n".format( - Nbad, Nact, Nbad * 100 / Nact - ) - ) + percent = Nbad * 100.0 / Nact + print(f"Overall bad: {Nbad} out of {Nact} active channels ({percent:2.1f}%)\n") else: Nbad = None @@ -726,10 +722,9 @@ def _median_filter(visi, ks=3): from scipy.signal import medfilt # Median filter visibilities: - cut_vis = np.array( + return np.array( [medfilt(visi[jj, :].real, kernel_size=ks) for jj in range(visi.shape[0])] ) - return cut_vis def _get_template(cut_vis_full, stand_chans): @@ -775,7 +770,7 @@ def _get_template(cut_vis_full, stand_chans): indices = np.delete(indices, max_ind) # Cut-out channels with largest deviations: cut_vis = np.array( - [cut_vis[jj, :] for jj in range(len(cut_vis)) if not (jj in del_ind)] + [cut_vis[jj, :] for jj in range(len(cut_vis)) if jj not in del_ind] ) Nchans = cut_vis.shape[0] # Number of channels after cut @@ -840,7 +835,7 @@ def _fit_template(Ts, cut_vis, tol): from scipy.optimize import curve_fit - class Template(object): + class Template: def __init__(self, tmplt): self.tmplt = tmplt @@ -903,9 +898,7 @@ def _create_plot( tmstp2 = cut_tmstp # For title, use start time stamp: - title = "Good channels result for {0}".format( - ctime.unix_to_datetime(tmstp1[0]).date() - ) + title = f"Good channels result for {ctime.unix_to_datetime(tmstp1[0]).date()}" # I need to know the slot for each channel: def get_slot(channel): @@ -973,7 +966,7 @@ def get_slot(channel): plt.ylim(med - 7.0 * mad, med + 7.0 * mad) # labels: - plt.ylabel("Ch{0} (Sl.{1})".format(chan, get_slot(chan)), fontsize=8) + plt.ylabel(f"Ch{chan} (Sl.{get_slot(chan)})", fontsize=8) # Hide numbering: frame = plt.gca() @@ -989,16 +982,12 @@ def get_slot(channel): # Put x-labels on bottom plots: if time_unit == "days": plt.xlabel( - "Time 
(days since {0} UTC)".format( - ctime.unix_to_datetime(tmstp1[0]) - ), + f"Time (days since {ctime.unix_to_datetime(tmstp1[0])} UTC)", fontsize=10, ) else: plt.xlabel( - "Time (hours since {0} UTC)".format( - ctime.unix_to_datetime(tmstp1[0]) - ), + f"Time (hours since {ctime.unix_to_datetime(tmstp1[0])} UTC)", fontsize=10, ) @@ -1011,7 +1000,7 @@ def get_slot(channel): elif chan == 192: plt.title("East cyl. P2(E-W)", fontsize=12) - filename = "plot_fit_{0}.pdf".format(int(time.time())) + filename = f"plot_fit_{int(time.time())}.pdf" plt.savefig(filename) plt.close() - print("Finished creating plot. File name: {0}".format(filename)) + print(f"Finished creating plot. File name: {filename}") diff --git a/ch_util/ephemeris.py b/ch_util/ephemeris.py index 712d7410..303d625f 100644 --- a/ch_util/ephemeris.py +++ b/ch_util/ephemeris.py @@ -2,46 +2,61 @@ This module is deprecated. For CHIME-specific stuff, use `ch_ephem`. + For instrument-independent stuff, use `caput`. """ import warnings -warnings.warn("The ch_util.ephemeris module is deprecated.", DeprecationWarning) - -from caput.interferometry import sphdist +from caput.interferometry import sphdist as sphdist from caput.time import ( - unix_to_datetime, - datetime_to_unix, - datetime_to_timestr, - timestr_to_datetime, - leap_seconds_between, - time_of_day, - Observer, - unix_to_skyfield_time, - skyfield_time_to_unix, - skyfield_star_from_ra_dec, - skyfield_wrapper, - ensure_unix, - SIDEREAL_S, - STELLAR_S, + unix_to_datetime as unix_to_datetime, + datetime_to_unix as datetime_to_unix, + datetime_to_timestr as datetime_to_timestr, + timestr_to_datetime as timestr_to_datetime, + leap_seconds_between as leap_seconds_between, + time_of_day as time_of_day, + Observer as Observer, + unix_to_skyfield_time as unix_to_skyfield_time, + skyfield_time_to_unix as skyfield_time_to_unix, + skyfield_star_from_ra_dec as skyfield_star_from_ra_dec, + skyfield_wrapper as skyfield_wrapper, + ensure_unix as ensure_unix, + SIDEREAL_S as SIDEREAL_S, + STELLAR_S as STELLAR_S, ) -from ch_ephem.coord import star_cirs as Star_cirs -from ch_ephem.coord import peak_ra as peak_RA +from ch_ephem.coord import star_cirs as Star_cirs # noqa F401 +from ch_ephem.coord import peak_ra as peak_RA # noqa F401 from ch_ephem.coord import ( - cirs_radec, - object_coords, - hadec_to_bmxy, - bmxy_to_hadec, - get_range_rate, + cirs_radec as cirs_radec, + object_coords as object_coords, + hadec_to_bmxy as hadec_to_bmxy, + bmxy_to_hadec as bmxy_to_hadec, + get_range_rate as get_range_rate, +) +from ch_ephem.pointing import ( + galt_pointing_model_ha as galt_pointing_model_ha, + galt_pointing_model_dec as galt_pointing_model_dec, +) +from ch_ephem.sources import ( + get_source_dictionary as get_source_dictionary, + source_dictionary as source_dictionary, + CasA as CasA, + CygA as CygA, + TauA as TauA, + VirA as VirA, +) +from ch_ephem.time import ( + parse_date as parse_date, + utc_lst_to_mjd as utc_lst_to_mjd, + chime_local_datetime as chime_local_datetime, ) -from ch_ephem.pointing import galt_pointing_model_ha, galt_pointing_model_dec -from ch_ephem.sources import get_source_dictionary, CasA, CygA, TauA, VirA -from ch_ephem.time import parse_date, utc_lst_to_mjd, chime_local_datetime from ch_ephem.observers import chime, tone, kko, gbo, hco -from .hfbcat import get_doppler_shifted_freq +from .hfbcat import get_doppler_shifted_freq as get_doppler_shifted_freq + +warnings.warn("The ch_util.ephemeris module is deprecated.", DeprecationWarning) CHIMELATITUDE = chime.latitude CHIMELONGITUDE 
= chime.longitude diff --git a/ch_util/finder.py b/ch_util/finder.py index 3aca06be..5fb58cc6 100644 --- a/ch_util/finder.py +++ b/ch_util/finder.py @@ -39,23 +39,22 @@ """ import logging -import os from os import path import time import socket import peewee as pw -import re import caput.time as ctime from ch_ephem.observers import chime -import chimedb.core as db +from chimedb.core.exceptions import CHIMEdbError import chimedb.data_index as di -from . import layout - +from chimedb.data_index.orm import file_info_table from chimedb.dataflag import DataFlagType, DataFlag +from . import layout +from ._db_tables import connect_peewee_tables as connect_database from .holography import HolographySource, HolographyObservation # Module Constants @@ -67,12 +66,6 @@ GF_ACCEPT = "gf_accept" -# Initializing connection to database. -# ==================================== - -from ._db_tables import connect_peewee_tables as connect_database - - # High level interface to the data index # ====================================== @@ -80,12 +73,8 @@ # finder. _acq_info_table = [di.CorrAcqInfo, di.HKAcqInfo, di.RawadcAcqInfo] -# Import list of tables that have a ``start_time`` and ``end_time`` -# field: they are necessary to do any time-based search. -from chimedb.data_index.orm import file_info_table - -class Finder(object): +class Finder: """High level searching of the CHIME data index. This class gives a convenient way to search and filter data acquisitions @@ -218,19 +207,19 @@ class Finder(object): 3 | 20141009T222415Z_ben_hk | 0 | 5745 | 2 | 0 >>> res = f.get_results(file_condition = (di.HKFileInfo.atmel_name == "LNA")) >>> for r in res: - ... print "No. files: %d" % (len(r[0])) + ... print("No. files: %d" % (len(r[0])) No. files: 8 No. files: 1 No. files: 19 No. files: 1 >>> data = res[0].as_loaded_data() >>> for m in data.mux: - ... print "Mux %d: %s", (m, data.chan(m)) + ... print("Mux %d: %s", (m, data.chan(m))) Mux 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] Mux 1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] Mux 2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] - >>> print "Here are the raw data for Mux 1, Channel 14:", data.tod(14, 1) - Here are the raw data for Mux 1, Channel 14: [ 1744.19091797 1766.34472656 1771.03356934 ..., 1928.61279297 1938.90075684 1945.53491211] + >>> print(data.tod(14, 1)) + [1744.190917, 1766.344726, 1771.033569, ..., 1928.612792, 1938.900756, 1945.534912] In the above example, the restriction to LNA housekeeping could also have been accomplished with the convenience method :meth:`Finder.set_hk_input`: @@ -374,8 +363,8 @@ def time_intervals(self): if self._time_intervals is None: return [self.time_range] - else: - return list(self._time_intervals) + + return list(self._time_intervals) def _append_time_interval(self, interval): if self._time_intervals is None: @@ -686,7 +675,7 @@ def set_time_range(self, start_time=None, end_time=None): ) self.filter_acqs_by_files(cond) - if not self._time_intervals is None: + if self._time_intervals is not None: time_intervals = _trim_intervals_range( self.time_intervals, (start_time, end_time) ) @@ -744,8 +733,8 @@ def _format_time_interval(self, start_time, end_time): end_time = min(end_time, range_end) if start_time < end_time: return (start_time, end_time) - else: - return None + + return None def include_time_interval(self, start_time, end_time): """Include a time interval. 
@@ -1018,7 +1007,7 @@ def include_26m_obs(self, source, require_quality=True): sources = sources.where(HolographySource.name == source) if len(sources) == 0: msg = ( - "No sources found in the database that match: {0}\n".format(source) + f"No sources found in the database that match: {source}\n" + "Returning full time range" ) logging.warning(msg) @@ -1030,7 +1019,7 @@ def include_26m_obs(self, source, require_quality=True): if require_quality: obs = obs.select().where( (HolographyObservation.quality_flag == 0) - | (HolographyObservation.quality_flag == None) + | (HolographyObservation.quality_flag == None) # noqa E712 ) found_obs = False @@ -1043,10 +1032,8 @@ def include_26m_obs(self, source, require_quality=True): self.include_time_interval(ob.start_time, ob.finish_time) if not found_obs: msg = ( - "No observation of the source ({0}) was found within the time range.\n".format( - source - ) - + "Returning full time range" + f"No observation of the source ({source}) " + "was found within the time range. Returning full time range" ) logging.warning(msg) @@ -1195,37 +1182,34 @@ def get_results_acq(self, acq_ind, file_condition=None): if mode is GF_ACCEPT: # Do nothing. continue - else: - # Need to actually get the flags. - global_flags = layout.global_flags_between( - acq_start, acq_finish, severity - ) - global_flag_names = [gf.name for gf in global_flags] - flag_times = [] - for f in global_flags: - start, stop = layout.get_global_flag_times(f.id) - if stop is None: - stop = time.time() - start = ctime.ensure_unix(start) - stop = ctime.ensure_unix(stop) - flag_times.append((start, stop)) - overlap = _check_intervals_overlap(time_intervals, flag_times) + + # Need to actually get the flags. + global_flags = layout.global_flags_between(acq_start, acq_finish, severity) + global_flag_names = [gf.name for gf in global_flags] + flag_times = [] + for f in global_flags: + start, stop = layout.get_global_flag_times(f.id) + if stop is None: + stop = time.time() + start = ctime.ensure_unix(start) + stop = ctime.ensure_unix(stop) + flag_times.append((start, stop)) + overlap = _check_intervals_overlap(time_intervals, flag_times) + if mode is GF_WARN: if overlap: msg = ( - "Global flag with severity '%s' present in data" + f"Global flag with severity '{severity}' present in data" " search results and warning requested." - " Global flag name: %s" - % (severity, global_flag_names[overlap[1]]) + " Global flag name: " + global_flag_names[overlap[1]] ) logging.warning(msg) elif mode is GF_RAISE: if overlap: msg = ( - "Global flag with severity '%s' present in data" + f"Global flag with severity '{severity}' present in data" " search results and exception requested." - " Global flag name: %s" - % (severity, global_flag_names[overlap[1]]) + " Global flag name: " + global_flag_names[overlap[1]] ) raise DataFlagged(msg) elif mode is GF_REJECT: @@ -1239,8 +1223,8 @@ def get_results_acq(self, acq_ind, file_condition=None): if len(self.data_flag_types) > 0: df_types = [t.name for t in DataFlagType.select()] for dft in self.data_flag_types: - if not dft in df_types: - raise RuntimeError("Could not find data flag type {}.".format(dft)) + if dft not in df_types: + raise RuntimeError(f"Could not find data flag type {dft}.") flag_times = [] for f in DataFlag.select().where( DataFlag.type == DataFlagType.get(name=dft) @@ -1386,7 +1370,7 @@ def print_results_summary(self): total_data += length total_size += s interval_number += 1 - print("Total %6.f seconds, %6.f MB of data." 
% (total_data, total_size)) + print(f"Total {total_data:6.f} seconds, {total_size:6.f} MB of data.") def _trim_intervals_range(intervals, time_range, min_interval=0.0): @@ -1397,8 +1381,8 @@ def _trim_intervals_range(intervals, time_range, min_interval=0.0): end = min(end, range_end) if end <= start + min_interval: continue - else: - out.append((start, end)) + + out.append((start, end)) return out @@ -1426,10 +1410,11 @@ def _check_intervals_overlap(intervals1, intervals2): start2, stop2 = intervals2[jj] if start1 < stop2 and start2 < stop1: return ii, jj + return None def _validate_gf_value(value): - if not value in (GF_REJECT, GF_RAISE, GF_WARN, GF_ACCEPT): + if value not in (GF_REJECT, GF_RAISE, GF_WARN, GF_ACCEPT): raise ValueError( "Global flag behaviour must be one of" " the *GF_REJECT*, *GF_RAISE*, *GF_WARN*, *GF_ACCEPT*" @@ -1441,7 +1426,7 @@ def _get_global_flag_times_by_name_event_id(flag): if isinstance(flag, str): event = ( layout.event.select() - .where(layout.event.active == True) + .where(layout.event.active == True) # noqa: E712 .join( layout.global_flag, on=(layout.event.graph_obj == layout.global_flag.id) ) @@ -1568,8 +1553,7 @@ def as_loaded_data(self, **kwargs): for k, v in kwargs.items(): if v is not None: setattr(reader, k, v) - data = reader.read() - return data + return reader.read() class CorrDataInterval(BaseDataInterval): @@ -1600,7 +1584,7 @@ def as_loaded_data(self, prod_sel=None, freq_sel=None, datasets=None): Data interval loaded into memory. """ - return super(CorrDataInterval, self).as_loaded_data( + return super().as_loaded_data( prod_sel=prod_sel, freq_sel=freq_sel, datasets=datasets ) @@ -1728,16 +1712,13 @@ def files_in_range( if not node_spoof: return [path.join(af.root, acq_name, af.name) for af in query] - else: - return [path.join(node_spoof[af.node_name], acq_name, af.name) for af in query] + + return [path.join(node_spoof[af.node_name], acq_name, af.name) for af in query] # Exceptions # ========== -# This is the base CHIMEdb exception -from chimedb.core.exceptions import CHIMEdbError - class DataFlagged(CHIMEdbError): """Raised when data is affected by a global flag.""" diff --git a/ch_util/fluxcat.py b/ch_util/fluxcat.py index 26e61327..636a957e 100644 --- a/ch_util/fluxcat.py +++ b/ch_util/fluxcat.py @@ -42,7 +42,7 @@ # ================================================================================== -class FitSpectrum(object, metaclass=ABCMeta): +class FitSpectrum(metaclass=ABCMeta): """A base class for modeling and fitting spectra. Any spectral model used by FluxCatalog should be derived from this class. @@ -146,12 +146,12 @@ class CurvedPowerLaw(FitSpectrum): freq_pivot : float The pivot frequency :math:`\\nu' = \\nu / freq_pivot`. Default is :py:const:`FREQ_NOMINAL`. - """ + """ # noqa: E501 def __init__(self, freq_pivot=FREQ_NOMINAL, nparam=2, *args, **kwargs): """Instantiates a CurvedPowerLaw object""" - super(CurvedPowerLaw, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # Set the additional model kwargs self.freq_pivot = freq_pivot @@ -271,7 +271,7 @@ def __contains__(self, item): return obj is not None -class FluxCatalog(object, metaclass=MetaFluxCatalog): +class FluxCatalog(metaclass=MetaFluxCatalog): """ Class for cataloging astronomical sources and predicting their flux density at radio frequencies based on spectral fits @@ -379,7 +379,7 @@ def __init__( # overwrite argument. if (overwrite < 2) and (name in FluxCatalog): # Return existing entry - print("%s already has an entry in catalog." 
% name, end=" ") + print(f"{name} already has an entry in catalog.", end=" ") if overwrite == 0: print("Returning existing entry.") self = FluxCatalog[name] @@ -442,9 +442,10 @@ def __init__( # Add alternate names to class dictionary so they can be searched quickly for alt_name in self.alternate_names: if alt_name in self._alternate_name_lookup: + alt_source = self._alternate_name_lookup[alt_name] warnings.warn( - "The alternate name %s is already held by the source %s." - % (alt_name, self._alternate_name_lookup[alt_name]) + f"The alternate name {alt_name} is already " + f"held by the source {alt_source}." ) else: self._alternate_name_lookup[alt_name] = self.name @@ -556,7 +557,6 @@ def plot(self, legend=True, catalog=True, residuals=False): Default is False. """ - import matplotlib import matplotlib.pyplot as plt # Define plot parameters @@ -710,9 +710,7 @@ def predict_flux(self, freq, epoch=None): else: args = [freq] - flux = self._model.predict(*args) - - return flux + return self._model.predict(*args) def predict_uncertainty(self, freq, epoch=None): """Calculate the uncertainty in the estimate of the flux density @@ -740,9 +738,7 @@ def predict_uncertainty(self, freq, epoch=None): else: args = [freq] - flux_uncertainty = self._model.uncertainty(*args) - - return flux_uncertainty + return self._model.uncertainty(*args) def to_dict(self): """Returns an ordered dictionary containing attributes @@ -768,21 +764,13 @@ def __str__(self): """Returns a string containing basic information about the source. Called by the print statement. """ - source_string = ( - "{0:<25.25s} {1:>6.2f} {2:>6.2f} {3:>6d} {4:^15.1f} {5:^15.1f}".format( - self.name, - self.ra, - self.dec, - len(self), - self.predict_flux(FREQ_NOMINAL), - 100.0 - * self.predict_uncertainty(FREQ_NOMINAL) - / self.predict_flux(FREQ_NOMINAL), - ) + flux = self.predict_flux(FREQ_NOMINAL) + percent = 100.0 * self.predict_uncertainty(FREQ_NOMINAL) / flux + return ( + f"{self.name:<25.25s} {self.ra:>6.2f} {self.dec:>6.2f} " + f"{len(self):>6d} {flux:^15.1f} {percent:^15.1f}" ) - return source_string - def __len__(self): """Returns the number of measurements of the source.""" return len(self.measurements) if self.measurements is not None else 0 @@ -793,20 +781,22 @@ def print_measurements(self): out = [] # Define header - hdr = "{0:<10s} {1:>8s} {2:>8s} {3:>6s} {4:>8s} {5:>8s} {6:<60s}".format( - "Frequency", "Flux", "Error", "Flag", "Catalog", "Epoch", "Citation" + hdr = ( + f"{'Frequency':<10s} {'Flux':>8s} {'Error':>8s} {'Flag':>6s} " + f"{'Catalog':>8s} {'Epoch':>8s} {'Citation':<60s}" ) - units = "{0:<10s} {1:>8s} {2:>8s} {3:>6s} {4:>8s} {5:>8s} {6:<60s}".format( - "[MHz]", "[Jy]", "[%]", "", "", "", "" + units = ( + f"{'[MHz]':<10s} {'[Jy]':>8s} {'[%]':>8s} " + f"{'':>6s} {'':>8s} {'':>8s} {'':<60s}" ) # Setup the title out.append("".join(["="] * max(len(hdr), len(units)))) - out.append("NAME: {0:s}".format(self.name.replace("_", " "))) - out.append("RA: {0:>6.2f} deg".format(self.ra)) - out.append("DEC: {0:>6.2f} deg".format(self.dec)) - out.append("{0:d} Measurements".format(len(self.measurements))) + out.append("NAME: " + self.name.replace("_", " ")) + out.append(f"RA: {self.ra:>6.2f} deg") + out.append(f"DEC: {self.dec:>6.2f} deg") + out.append(f"{len(self.measurements):d} Measurements") out.append("".join(["-"] * max(len(hdr), len(units)))) out.append(hdr) @@ -910,17 +900,14 @@ def string(cls): catalog_string = [] # Print the header - hdr = "{0:<25s} {1:^6s} {2:^6s} {3:>6s} {4:^15s} {5:^15s}".format( - "Name", "RA", "Dec", 
"Nmeas", "Flux", "Error" + hdr = ( + f"{'Name':<25s} {'RA':^6s} {'Dec':^6s} " + f"{'Nmeas':>6s} {'Flux':^15s} {'Error':^15s}" ) - units = "{0:<25s} {1:^6s} {2:^6s} {3:>6s} {4:^15s} {5:^15s}".format( - "", - "[deg]", - "[deg]", - "", - "@%d MHz [Jy]" % FREQ_NOMINAL, - "@%d MHz [%%]" % FREQ_NOMINAL, + units = ( + f"{'':<25s} {'[deg]':^6s} {'[deg]':^6s} {'':>6s} " + f"{f'@{FREQ_NOMINAL} MHz [Jy]':^15s} {f'@{FREQ_NOMINAL} MHz [%]':^15s}" ) catalog_string.append("".join(["-"] * max(len(hdr), len(units)))) @@ -1003,7 +990,7 @@ def get(cls, key): # Check if the object was found if obj is None: - raise KeyError("%s was not found." % fkey) + raise KeyError(f"{fkey} was not found.") # Return the body corresponding to this source return obj @@ -1147,7 +1134,7 @@ def available_collections(cls): full_path = os.path.join(root, filename) # Read into dictionary - with open(full_path, "r") as fp: + with open(full_path) as fp: collection_dict = json.load(fp, object_hook=json_numpy_obj_hook) # Append (path, number of sources, source names) to list @@ -1219,7 +1206,7 @@ def dump(cls, filename): ext = os.path.splitext(filename)[1] if ext not in [".pickle", ".json"]: - raise ValueError("Do not recognize '%s' extension." % ext) + raise ValueError(f"Do not recognize '{ext}' extension.") try: os.makedirs(path) @@ -1273,13 +1260,13 @@ def load(cls, filename, overwrite=0, set_globals=False, verbose=False): # Check if the file actually exists and has the correct extension if not os.path.isfile(filename): - raise ValueError("%s does not exist." % filename) + raise ValueError(f"{filename} does not exist.") if ext not in [".pickle", ".json"]: - raise ValueError("Do not recognize '%s' extension." % ext) + raise ValueError(f"Do not recognize '{ext}' extension.") # Load contents of file into a dictionary - with open(filename, "r") as fp: + with open(filename) as fp: if ext == ".json": collection_dict = json.load(fp, object_hook=json_numpy_obj_hook) elif ext == ".pickle": @@ -1430,11 +1417,8 @@ def format_source_name(input_name): # Remove multiple spaces. Replace single spaces with underscores. output_name = "_".join(output_name.split()) - # Put the name in all uppercase. - output_name = output_name.upper() - - # Return properly formatted name - return output_name + # Return the name in all uppercase. + return output_name.upper() class NumpyEncoder(json.JSONEncoder): @@ -1450,7 +1434,11 @@ def default(self, obj): assert cont_obj.flags["C_CONTIGUOUS"] obj_data = cont_obj.data data_b64 = base64.b64encode(obj_data) - return dict(__ndarray__=data_b64, dtype=str(obj.dtype), shape=obj.shape) + return { + "__ndarray__": data_b64, + "dtype": str(obj.dtype), + "shape": obj.shape, + } # Let the base class default method raise the TypeError return json.JSONEncoder(self, obj) diff --git a/ch_util/hfbcat.py b/ch_util/hfbcat.py index 013fc4f3..032e28d0 100644 --- a/ch_util/hfbcat.py +++ b/ch_util/hfbcat.py @@ -3,6 +3,7 @@ """ from __future__ import annotations + import numpy as np from typing import TYPE_CHECKING, Optional, Union @@ -156,8 +157,8 @@ def get_doppler_shifted_freq( "in ch_util.hfbcat.HFBCatalog. 
" f"Got source {source} with names {source.names}" ) - else: - freq_rest = HFBCatalog[source.names].freq_abs + + freq_rest = HFBCatalog[source.names].freq_abs # Prepare rest frequencies for broadcasting freq_rest = np.asarray(ensure_list(freq_rest))[:, np.newaxis] @@ -169,9 +170,7 @@ def get_doppler_shifted_freq( # Compute observed frequencies from rest frequencies # using relativistic Doppler effect beta = range_rate / speed_of_light - freq_obs = freq_rest * np.sqrt((1.0 - beta) / (1.0 + beta)) - - return freq_obs + return freq_rest * np.sqrt((1.0 - beta) / (1.0 + beta)) # Load the HFB target list diff --git a/ch_util/holography.py b/ch_util/holography.py index fa6e3bb4..c8a8485c 100644 --- a/ch_util/holography.py +++ b/ch_util/holography.py @@ -199,26 +199,27 @@ def parse_post_report(cls, post_report_file): output_params = {} - with open(post_report_file, "r") as f: - lines = [line for line in f] - for l in lines: - if (l.find("Source")) != -1: - srcnm = re.search("Source:\s+(.*?)\s+", l).group(1) + with open(post_report_file) as f: + for line in f: + if (line.find("Source")) != -1: + srcnm = re.search("Source:\s+(.*?)\s+", line).group(1) if srcnm in cls.source_alias: srcnm = cls.source_alias[srcnm] - if (l.find("DURATION")) != -1: + if (line.find("DURATION")) != -1: output_params["DURATION"] = float( - re.search("DURATION:\s+(.*?)\s+", l).group(1) + re.search("DURATION:\s+(.*?)\s+", line).group(1) ) # convert Julian Date to Skyfield time object - if (l.find("JULIAN DATE")) != -1: + if (line.find("JULIAN DATE")) != -1: output_params["start_time"] = ts.ut1( - jd=float(re.search("JULIAN DATE:\s+(.*?)\s+", l).group(1)) + jd=float(re.search("JULIAN DATE:\s+(.*?)\s+", line).group(1)) ) - if l.find("SID:") != -1: - output_params["SID"] = int(re.search("SID:\s(.*?)\s+", l).group(1)) + if line.find("SID:") != -1: + output_params["SID"] = int( + re.search("SID:\s(.*?)\s+", line).group(1) + ) try: output_params["src"] = HolographySource.get(name=srcnm) except pw.DoesNotExist: @@ -285,7 +286,7 @@ def create_from_ant_logs( ant_log["dec"], ) if verbose: - print("onsource_dist = {:.2f} deg".format(onsource_dist)) + print(f"onsource_dist = {onsource_dist:.2f} deg") onsource = np.where(dist.degrees < onsource_dist)[0] if len(onsource) > 0: @@ -303,14 +304,13 @@ def create_from_ant_logs( if stdoffset > 0.05 or meanoffset > ONSOURCE_DIST_TO_FLAG: obs["quality_flag"] += QUALITY_OFFSOURCE print( - ( - "Mean offset: {:.4f}. Std offset: {:.4f}. " - "Setting quality flag to {}." - ).format(meanoffset, stdoffset, QUALITY_OFFSOURCE) + f"Mean offset: {meanoffset:.4f}. " + f"Std offset: {stdoffset:.4f}. " + f"Setting quality flag to {QUALITY_OFFSOURCE}." ) noteout = ( "Questionable on source. Mean, STD(offset) : " - "{:.3f}, {:.3f}. {}".format(meanoffset, stdoffset, noteout) + f"{meanoffset:.3f}, {stdoffset:.3f}. {noteout}" ) obs["quality_flag"] |= quality_flag if verbose: @@ -331,9 +331,8 @@ def create_from_ant_logs( ) ) print( - "Mean offset: {:.4f}. Std offset: {:.4f}.".format( - meanoffset, stdoffset - ) + f"Mean offset: {meanoffset:.4f}. " + f"Std offset: {stdoffset:.4f}." ) cls.create_from_dict(obs, verbose=verbose, notes=noteout, **kwargs) @@ -449,23 +448,20 @@ def check_for_duplicates(t, src, start_tol, ignore_src_mismatch=False): # check possible. 
if dup_found: - tf = ts.utc(ctime.unix_to_datetime(entry.finish_time)) print( - "Tried to add : {} {}; LST={:.3f}".format( - src.name, t.utc_datetime().strftime(DATE_FMT_STR), ttlst - ) + f"Tried to add : {src.name} " + f"{t.utc_datetime().strftime(DATE_FMT_STR)}; " + f"LST={ttlst:.3f}" ) print( - "Existing entry: {} {}; LST={:.3f}".format( - entry.source.name, - tt.utc_datetime().strftime(DATE_FMT_STR), - ttlst, - ) + f"Existing entry: {entry.source.name} " + f"{tt.utc_datetime().strftime(DATE_FMT_STR)}; " + f"LST={ttlst:.3f}" ) if dup_found: return existing_db_entry - else: - return None + + return None # DRAO longitude in hours DRAO_lon = chime.longitude * 24.0 / 360.0 @@ -565,8 +561,6 @@ def parse_ant_logs(cls, logs, return_post_report_params=False): from skyfield.positionlib import Angle - DRAO_lon = chime.longitude * 24.0 / 360.0 - def sidlst_to_csd(sid, lst, sid_ref, t_ref): """ Convert an integer DRAO sidereal day and LST to a float @@ -601,7 +595,7 @@ def sidlst_to_csd(sid, lst, sid_ref, t_ref): doobs = True filename = log.split("/")[-1] - basedir = "/tmp/26mlog/{}/".format(os.getlogin()) + basedir = f"/tmp/26mlog/{os.getlogin()}/" basename, extension = filename.split(".") post_report_file = basename + ".POST_REPORT" ant_file = basename + ".ANT" @@ -609,20 +603,18 @@ def sidlst_to_csd(sid, lst, sid_ref, t_ref): if extension == "zip": try: zipfile.ZipFile(log).extract(post_report_file, path=basedir) - except: + except ValueError: print( - "Failed to extract {} into {}. Moving right along...".format( - post_report_file, basedir - ) + f"Failed to extract {post_report_file} into {basedir}. " + "Moving right along..." ) doobs = False try: zipfile.ZipFile(log).extract(ant_file, path=basedir) - except: + except ValueError: print( - "Failed to extract {} into {}. Moving right along...".format( - ant_file, basedir - ) + f"Failed to extract {ant_file} into {basedir}. " + "Moving right along..." 
) doobs = False @@ -632,8 +624,7 @@ def sidlst_to_csd(sid, lst, sid_ref, t_ref): basedir + post_report_file ) - with open(os.path.join(basedir, ant_file), "r") as f: - lines = [line for line in f] + with open(os.path.join(basedir, ant_file)) as f: ant_data = {"sid": np.array([])} lsth = [] lstm = [] @@ -647,8 +638,8 @@ def sidlst_to_csd(sid, lst, sid_ref, t_ref): decm = [] decs = [] - for l in lines: - arr = l.split() + for line in f: + arr = line.split() try: lst_hms = [float(x) for x in arr[2].split(":")] @@ -672,12 +663,8 @@ def sidlst_to_csd(sid, lst, sid_ref, t_ref): ant_data["sid"] = np.append( ant_data["sid"], int(arr[1]) ) - except: - print( - "Failed in file {} for line \n{}".format( - ant_file, l - ) - ) + except IndexError: + print(f"Failed in file {ant_file} for line \n{line}") if len(ant_data["sid"]) != len(decs): print("WARNING: mismatch in list lengths.") @@ -735,8 +722,8 @@ def sidlst_to_csd(sid, lst, sid_ref, t_ref): ant_data_list.append(ant_data) post_report_list.append(post_report_params) - except: - print("Parsing {} failed".format(post_report_file)) + except (ValueError, KeyError): + print(f"Parsing {post_report_file} failed") if return_post_report_params: return post_report_list, ant_data_list @@ -784,12 +771,13 @@ def create_from_post_reports( Example ------- - from ch_util import holography as hl - import glob - obs = hl.HolographyObservation - logs = glob.glob('/path/to/logs/*JUN18*.zip') - obs_list, dup_obs_list, missing = obs.create_from_post_reports(logs, dryrun=False) + >>> from ch_util import holography as hl + >>> import glob + >>> obs = hl.HolographyObservation + >>> logs = glob.glob('/path/to/logs/*JUN18*.zip') + >>> obs_list, dup_obs_list, missing = \ + ... obs.create_from_post_reports(logs, dryrun=False) """ # check notes. Can be a string (in which case duplicate it), None (in # which case do nothing) or a list (in which case use it if same length @@ -807,7 +795,7 @@ def create_from_post_reports( for log, note in zip(logs, notesarr): if verbose: - print("Working on {}".format(log)) + print(f"Working on {log}") filename = log.split("/")[-1] # basedir = '/'.join(log.split('/')[:-1]) + '/' basedir = "/tmp/" @@ -820,12 +808,8 @@ def create_from_post_reports( if extension == "zip": try: zipfile.ZipFile(log).extract(post_report_file, path=basedir) - except Exception: - print( - "failed to find {}. Moving right along...".format( - post_report_file - ) - ) + except ValueError: + print(f"failed to find {post_report_file}. Moving right along...") doobs = False elif extension != "POST_REPORT": print( diff --git a/ch_util/layout.py b/ch_util/layout.py index 19609b44..0969a63c 100644 --- a/ch_util/layout.py +++ b/ch_util/layout.py @@ -125,69 +125,70 @@ - :py:const:`EVENT_ALL` - :py:const:`ORDER_ASC` - :py:const:`ORDER_DESC` -""" +""" # noqa: E501 import datetime -import inspect import logging +from logging import NullHandler import networkx as nx import os import peewee as pw -import re import chimedb.core import caput.time as ctime -_property = property # Do this since there is a class "property" in _db_tables. +# Do this since there is a class "property" in _db_tables. 
+from builtins import property as _property from ._db_tables import ( - EVENT_AT, - EVENT_BEFORE, - EVENT_AFTER, - EVENT_ALL, - ORDER_ASC, - ORDER_DESC, - _check_fail, - _plural, - _are, - AlreadyExists, - NoSubgraph, - BadSubgraph, - DoesNotExist, - UnknownUser, - NoPermission, - LayoutIntegrity, - PropertyType, - PropertyUnchanged, - ClosestDraw, - set_user, - graph_obj, - global_flag_category, - global_flag, - component_type, - component_type_rev, - external_repo, - component, - component_history, - component_doc, - connexion, - property_type, - property_component, - property, - event_type, - timestamp, - event, - predef_subgraph_spec, - predef_subgraph_spec_param, - user_permission_type, - user_permission, - compare_connexion, - add_component, - remove_component, - set_property, - make_connexion, - sever_connexion, + EVENT_AT as EVENT_AT, + EVENT_BEFORE as EVENT_BEFORE, + EVENT_AFTER as EVENT_AFTER, + EVENT_ALL as EVENT_ALL, + ORDER_ASC as ORDER_ASC, + ORDER_DESC as ORDER_DESC, + _check_fail as _check_fail, + _plural as _plural, + _are as _are, + AlreadyExists as AlreadyExists, + NoSubgraph as NoSubgraph, + BadSubgraph as BadSubgraph, + DoesNotExist as DoesNotExist, + UnknownUser as UnknownUser, + NoPermission as NoPermission, + LayoutIntegrity as LayoutIntegrity, + PropertyType as PropertyType, + PropertyUnchanged as PropertyUnchanged, + ClosestDraw as ClosestDraw, + set_user as set_user, + graph_obj as graph_obj, + global_flag_category as global_flag_category, + global_flag as global_flag, + component_type as component_type, + component_type_rev as component_type_rev, + external_repo as external_repo, + component as component, + component_history as component_history, + component_doc as component_doc, + connexion as connexion, + property_type as property_type, + property_component as property_component, + property as property, + event_type as event_type, + timestamp as timestamp, + event as event, + predef_subgraph_spec as predef_subgraph_spec, + predef_subgraph_spec_param as predef_subgraph_spec_param, + user_permission_type as user_permission_type, + user_permission as user_permission, + compare_connexion as compare_connexion, + add_component as add_component, + remove_component as remove_component, + set_property as set_property, + make_connexion as make_connexion, + sever_connexion as sever_connexion, ) +from ._db_tables import connect_peewee_tables as connect_database # Legacy name from chimedb.core import NotFoundError as NotFound @@ -197,13 +198,10 @@ # Logging # ======= -# Set default logging handler to avoid "No handlers could be found for logger -# 'layout'" warnings. -from logging import NullHandler - - # All peewee-generated logs are logged to this namespace. logger = logging.getLogger("layout") +# Set default logging handler to avoid "No handlers could be found for logger +# 'layout'" warnings. logger.addHandler(NullHandler()) @@ -211,7 +209,7 @@ # ======= -class subgraph_spec(object): +class subgraph_spec: """Specifications for extracting a subgraph from a full graph. The subgraph specification can be created from scratch by passing the @@ -271,7 +269,7 @@ class method :meth:`FROM_PREDef`. Most of them are as short as we would expect, but there are some complications. Let's look at that first one by printing out its LTF: - >>> print sg[0].ltf + >>> print(sg[0].ltf) # C-can thru to RFT thru. CANAD0B RFTA15B attenuation=10 therm_avail=ch7 @@ -303,7 +301,7 @@ class method :meth:`FROM_PREDef`. 
>>> sg_spec.terminate += ["200m coax", "HK hydra", "50m coax"] >>> sg_spec.hide += ["200m coax", "HK hydra", "50m coax"] >>> sg = layout.graph.from_db(datetime(2014, 10, 5, 12, 0), sg_spec) - >>> print [s.order() for s in sg] + >>> print([s.order() for s in sg]) [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, @@ -323,7 +321,7 @@ class method :meth:`FROM_PREDef`. CANBL1B >>> sg_spec.hide = [] >>> bad_sg = layout.graph.from_db(datetime(2014, 10, 5, 12, 0), sg_spec, sn) - >>> print bad_sg.ltf() + >>> print(bad_sg.ltf()) # C-can thru to c-can thru. CANBL1B CXS0017 @@ -342,7 +340,7 @@ class method :meth:`FROM_PREDef`. >>> sg_spec.oneway = [["SMA coax", "RFT thru"]] >>> bad_sg = layout.graph.from_db(datetime(2014, 10, 5, 12, 0), sg_spec, sn) - >>> print bad_sg.ltf() + >>> print(bad_sg.ltf()) # C-can thru to RFT thru. CANBL1B CXS0017 @@ -380,7 +378,7 @@ def from_predef(cls, predef): elif param.action == "H": h.append(param.type1_id) else: - raise RuntimeError('Unknown subgraph action type "%s".' % param.action) + raise RuntimeError(f'Unknown subgraph action type "{param.action}".') return cls(s, t, o, h) @_property @@ -453,14 +451,16 @@ class and adds CHIME-specific functionality. `networkx.Graph `_ methods: - >>> print g.order(), g.size() + >>> print(g.order(), g.size()) 2483 2660 There are some convenience methods for our implementation. For example, you can easily find components by component type: - >>> print g.component(type = "reflector") - [, , ] + >>> print(g.component(type = "reflector")) + [, + , + ] Note that the graph nodes are :obj:`component` objects. You can also use the :meth:`component` method to search for components by serial number: @@ -469,15 +469,17 @@ class and adds CHIME-specific functionality. Node properties are stored as per usual for :class:`networkx.Graph` objects: - >>> print g.nodes[ant] - {'_rev_id': 11L, '_type_id': 2L, u'pol1_orient': , '_type_name': u'antenna', '_id': 32L, u'pol2_orient': , '_rev_name': u'B'} + >>> print(g.nodes[ant]) + {'_rev_id': 11L, '_type_id': 2L, '_type_name': 'antenna', '_id': 32L, + 'pol1_orient': , + 'pol2_orient': , '_rev_name': 'B'} Note, however, that there are some internally-used properties (starting with an underscore). The :meth:`node_property` returns a dictionary of properties without these private memebers: >>> for p in g.node_property(ant).values(): - ... print "%s = %s %s" % (p.type.name, p.value, p.type.units if p.type.units else "") + ... print(p.type.name, "=", p.value, p.type.units if p.type.units else "") pol1_orient = S pol2_orient = E @@ -485,7 +487,7 @@ class and adds CHIME-specific functionality. component, using :meth:`closest_of_type`: >>> slt_type = layout.component_type.get(name = "cassette slot") - >>> print g.closest_of_type(ant, slt_type).sn + >>> print(g.closest_of_type(ant, slt_type).sn) CSS004C0 Use of :meth:`closest_of_type` can be subtle for components separated by long @@ -504,8 +506,8 @@ def __init__(self, time=datetime.datetime.now()): self._time = time self._sg_spec = None self._sg_spec_start = None - self._sn_dict = dict() - self._ctype_dict = dict() + self._sn_dict = {} + self._ctype_dict = {} # We will cache all the component types, revisions and properties now, # since these will be used constantly by the graph. 
@@ -572,13 +574,12 @@ def from_db(cls, time=datetime.datetime.now(), sg_spec=None, sg_start_sn=None):
                 "JOIN timestamp t1 ON e.start_id = t1.id "
                 "LEFT JOIN timestamp t2 ON e.end_id = t2.id "
                 "WHERE e.active = 1 AND e1.type_id = 1 AND e2.type_id = 1 AND "
-                "e1t1.time <= '%s' AND "
-                "(e1.end_id IS NULL OR e1t2.time > '%s') AND "
-                "e2t1.time <= '%s' AND "
-                "(e2.end_id IS NULL OR e2t2.time > '%s') AND "
-                "t1.time <= '%s' AND "
-                "(e.end_id IS NULL OR t2.time > '%s');"
-                % (time, time, time, time, time, time)
+                f"e1t1.time <= '{time}' AND "
+                f"(e1.end_id IS NULL OR e1t2.time > '{time}') AND "
+                f"e2t1.time <= '{time}' AND "
+                f"(e2.end_id IS NULL OR e2t2.time > '{time}') AND "
+                f"t1.time <= '{time}' AND "
+                f"(e.end_id IS NULL OR t2.time > '{time}');"
             )
             # print sql
             conn_list = chimedb.core.proxy.execute_sql(sql)
@@ -604,10 +605,10 @@ def from_db(cls, time=datetime.datetime.now(), sg_spec=None, sg_start_sn=None):
                 "JOIN timestamp t1 ON e.start_id = t1.id "
                 "LEFT JOIN timestamp t2 ON e.end_id = t2.id "
                 "WHERE e.active = 1 AND ce.type_id = 1 AND "
-                "ct1.time <= '%s' AND "
-                "(ce.end_id IS NULL OR ct2.time > '%s') AND "
-                "t1.time <= '%s' AND "
-                "(e.end_id IS NULL OR t2.time > '%s');" % (time, time, time, time)
+                f"ct1.time <= '{time}' AND "
+                f"(ce.end_id IS NULL OR ct2.time > '{time}') AND "
+                f"t1.time <= '{time}' AND "
+                f"(e.end_id IS NULL OR t2.time > '{time}');"
             )
             prop_list = chimedb.core.proxy.execute_sql(sql)
             for r in prop_list:
@@ -618,8 +619,8 @@ def from_db(cls, time=datetime.datetime.now(), sg_spec=None, sg_start_sn=None):
 
         if sg_spec:
             return graph.from_graph(g, sg_spec, sg_start_sn)
-        else:
-            return g
+
+        return g
 
     def _ensure_add(self, id, sn, type, rev):
         """Robustly add a component, avoiding duplication."""
@@ -667,11 +668,11 @@ def node_property(self, n):
         >>> g = layout.graph.from_db(datetime(2014, 10, 5, 12, 0))
         >>> rft = g.component(comp = "RFTK07B")
         >>> for p in g.node_property(rft).values():
-        ...     print "%s = %s %s" % (p.type.name, p.value, p.type.units if p.type.units else "")
+        ...     print(p.type.name, "=", p.value, p.type.units if p.type.units else "")
         attenuation = 10 dB
         therm_avail = ch1
         """
-        ret = dict()
+        ret = {}
         for key, val in self.nodes[n].items():
             if key[0] != "_":
                 ret[key] = val
@@ -712,7 +713,7 @@ def component(self, comp=None, type=None, sort_sn=False):
        >>> from ch_util import graph
        >>> from datetime import datetime
        >>> g = layout.graph.from_db(datetime(2014, 10, 5, 12, 0))
-        >>> print g.component("CXA0005A").type_rev.name
+        >>> print(g.component("CXA0005A").type_rev.name)
        B
        >>> for r in g.component(type = "reflector"):
        ...     print r.sn
@@ -730,7 +731,7 @@ def component(self, comp=None, type=None, sort_sn=False):
             try:
                 ret = self._sn_dict[sn]
             except KeyError:
-                raise NotFound('Serial number "%s" is not in the graph.' % (sn))
+                raise NotFound(f'Serial number "{sn}" is not in the graph.')
         elif not type:
             ret = self.nodes()
         else:
@@ -745,9 +746,7 @@ def component(self, comp=None, type=None, sort_sn=False):
                 if sort_sn:
                     ret.sort(key=lambda x: x.sn)
             except KeyError:
-                raise NotFound(
-                    'No components of type "%s" are in the graph.' % type_name
-                )
+                raise NotFound(f'No components of type "{type_name}" are in the graph.')
         return ret
 
     def _subgraph_recurse(self, gr, comp1, sg, done, last_no_hide):
@@ -814,7 +813,7 @@ def from_graph(cls, g, sg_spec=None, sg_start_sn=None):
             A list of :obj:`graph` objects, one for each subgraph found. If,
             however, *g* is set to :obj:`None`, a reference to the input graph is
             returned.
""" - if sg_spec == None: + if sg_spec is None: return g if sg_spec.start in sg_spec.terminate: raise BadSubgraph( @@ -841,8 +840,8 @@ def from_graph(cls, g, sg_spec=None, sg_start_sn=None): raise NotFound("No subgraph was found.") if sg_start_sn: return ret[-1] - else: - return ret + + return ret def _print_chain(self, chain): if len(chain) <= 1: @@ -851,11 +850,11 @@ def _print_chain(self, chain): ret = "" ctype1 = chain[0].type.name ctype2 = chain[-1].type.name - ret = "# %s to %s.\n" % (ctype1[0].upper() + ctype1[1:], ctype2) + ret = "# " + (ctype1[0].upper() + ctype1[1:]) + " to " + ctype2 + ".\n" for c in chain: ret += c.sn for prop, value in self.node_property(c).items(): - ret += " %s=%s" % (prop, value.value) + ret += " " + prop + "=" + value.value ret += "\n" ret += "\n" @@ -923,8 +922,8 @@ def ltf(self): layout.component_type.get(name = "HK preamp").id, layout.component_type.get(name = "HK hydra").id] >>> sg_spec = layout.subgraph_spec(start, terminate, [], hide) - >>> sg = layout.graph.from_db(datetime(2014, 11, 20, 12, 0), sg_spec, "ANT0108B") - >>> print sg.ltf() + >>> sg = layout.graph.from_db(datetime(2014, 11, 20, 12, 0), sg_spec,"ANT0108B") + >>> print(sg.ltf()) # Antenna to correlator input. ANT0108B pol1_orient=S pol2_orient=E PL0108B1 @@ -1052,8 +1051,8 @@ def shortest_path_to_type(self, comp, type, type_exclude=None, ignore_draws=True # Return the shortest path (or None if not found) if one: return shortest[0] - else: - return shortest + + return shortest def closest_of_type(self, comp, type, type_exclude=None, ignore_draws=True): """Searches for the closest connected component of a given type. @@ -1099,38 +1098,41 @@ def closest_of_type(self, comp, type, type_exclude=None, ignore_draws=True): >>> import layout >>> from datetime import datetime >>> g = layout.graph.from_db(datetime(2014, 11, 5, 12, 0)) - >>> print g.closest_of_type("ANT0044B", "cassette slot").sn + >>> print(g.closest_of_type("ANT0044B", "cassette slot").sn) CSS004C0 The example above is simple as the two components are adjacent: - >>> print [c.sn for c in g.shortest_path_to_type("ANT0044B", "cassette slot")] - [u'ANT0044B', u'CSS004C0'] + >>> print([c.sn for c in g.shortest_path_to_type("ANT0044B", "cassette slot")]) + ['ANT0044B', 'CSS004C0'] In general, though, you need to take care when using this method and make judicious use of the **type_exclude** parameter. For example, consider the following example: - >>> print g.closest_of_type("K7BP16-00040112", "RFT thru").sn + >>> print(g.closest_of_type("K7BP16-00040112", "RFT thru").sn) RFTB15B It seems OK on the surface, but the path it has used is probably not what you want: - >>> print [c.sn for c in g.shortest_path_to_type("K7BP16-00040112", "RFT thru")] - [u'K7BP16-00040112', u'K7BP16-000401', u'K7BP16-00040101', u'FLA0280B', u'RFTB15B'] + >>> print([c.sn for c in g.shortest_path_to_type("K7BP16-00040112","RFT thru")]) + ['K7BP16-00040112', 'K7BP16-000401', 'K7BP16-00040101', 'FLA0280B', 'RFTB15B'] We need to block the searcher from going into the correlator card slot and then back out another input, which we can do like so: - >>> print g.closest_of_type("K7BP16-00040112", "RFT thru", type_exclude = "correlator card slot").sn + >>> print(g.closest_of_type("K7BP16-00040112", "RFT thru", + ... type_exclude = "correlator card slot").sn) RFTQ15B The reason the first search went through the correlator card slot is because there are delay cables and splitters involved. 
- >>> print [c.sn for c in g.shortest_path_to_type("K7BP16-00040112", "RFT thru", type_exclude = "correlator card slot")] - [u'K7BP16-00040112', u'CXS0279', u'CXA0018A', u'CXA0139B', u'SPL001AP2', u'SPL001A', u'SPL001AP3', u'CXS0281', u'RFTQ15B'] + >>> print([c.sn for c in g.shortest_path_to_type("K7BP16-00040112", + ... "RFT thru", type_exclude = "correlator card slot")]) + ['K7BP16-00040112', 'CXS0279', 'CXA0018A', 'CXA0139B', 'SPL001AP2', + 'SPL001A', 'SPL001AP3', 'CXS0281', 'RFTQ15B'] The shortest path really was through the correlator card slot, until we explicitly rejected such paths. @@ -1238,8 +1240,8 @@ def _add_to_chain(chain, sn, prop, sever, fail_comp): if chain[-1] == "//": if len(chain) < 2: raise SyntaxError( - 'Confused about chain ending in "%s". Is the ' - "first serial number valid?" % (chain[-1]) + f'Confused about chain ending in "{chain[-1]}". ' + "Is the first serial number valid?" ) try: _add_to_sever(chain[-2]["comp"].sn, sn, sever, fail_comp) @@ -1248,32 +1250,31 @@ def _add_to_chain(chain, sn, prop, sever, fail_comp): del chain[-2] del chain[-1] - chain.append(dict()) + chain.append({}) try: chain[-1]["comp"] = component.get(sn=sn) for k in range(len(prop)): if len(prop[k].split("=")) != 2: - raise SyntaxError('Confused by the property command "%s".' % prop[k]) + raise SyntaxError(f'Confused by the property command "{prop[k]}".') chain[-1][prop[k].split("=")[0]] = prop[k].split("=")[1] except pw.DoesNotExist: - if not sn in fail_comp: + if sn not in fail_comp: fail_comp.append(sn) def _id_from_multi(cls, o): if isinstance(o, int): return o - elif isinstance(o, cls): + + if isinstance(o, cls): return o.id - else: - return cls.get(name=o).id + + return cls.get(name=o).id # Public Functions # ================ -from ._db_tables import connect_peewee_tables as connect_database - def enter_ltf(ltf, time=datetime.datetime.now(), notes=None, force=False): """Enter an LTF into the database. @@ -1297,9 +1298,9 @@ def enter_ltf(ltf, time=datetime.datetime.now(), notes=None, force=False): """ try: - with open(ltf, "r") as myfile: + with open(ltf) as myfile: ltf = myfile.readlines() - except IOError: + except OSError: try: ltf = ltf.splitlines() except AttributeError: @@ -1311,31 +1312,31 @@ def enter_ltf(ltf, time=datetime.datetime.now(), notes=None, force=False): chain.append([]) sever = [] i = 0 - for l in ltf: - if len(l) and l[0] == "#": + for line in ltf: + if len(line) and line[0] == "#": continue severed = False try: - if l.split()[1] == "//": + if line.split()[1] == "//": severed = True except IndexError: pass - if not len(l) or l.isspace() or severed or l[0:2] == "$$": + if not len(line) or line.isspace() or severed or line[0:2] == "$$": if severed: - _add_to_sever(l.split()[0], l.split()[2], sever, fail_comp) + _add_to_sever(line.split()[0], line.split()[2], sever, fail_comp) if multi_sn: - _add_to_chain(chain[i], multi_sn, prop, sever, fail_comp) + _add_to_chain(chain[i], multi_sn, multi_prop, sever, fail_comp) multi_sn = False chain.append([]) i += 1 continue - l = l.replace("\n", "") - l = l.strip() + line = line.replace("\n", "") + line = line.strip() - sn = l.split()[0] - prop = l.split()[1:] + sn = line.split()[0] + prop = line.split()[1:] # Check to see if this is a multiple-line SN. 
if multi_sn: @@ -1372,9 +1373,8 @@ def enter_ltf(ltf, time=datetime.datetime.now(), notes=None, force=False): fail_comp, False, DoesNotExist, - "The following component%s " - "%s not in the DB and must be added first" - % (_plural(fail_comp), _are(fail_comp)), + f"The following component{_plural(fail_comp)} " + f"{_are(fail_comp)} not in the DB and must be added first", ) conn_list = [] @@ -1384,16 +1384,13 @@ def enter_ltf(ltf, time=datetime.datetime.now(), notes=None, force=False): comp1 = c[i - 1]["comp"] comp2 = c[i]["comp"] if comp1.sn == comp2.sn: - logger.info( - "Skipping auto connexion: %s <=> %s." % (comp1.sn, comp2.sn) - ) + logger.info(f"Skipping auto connexion: {comp1.sn} <=> {comp2.sn}.") else: conn = connexion.from_pair(comp1, comp2) try: if conn.is_permanent(time): logger.info( - "Skipping permanent connexion: %s <=> %s." - % (comp1.sn, comp2.sn) + f"Skipping permanent connexion: {comp1.sn} <=> {comp2.sn}." ) elif conn not in conn_list: conn_list.append(conn) @@ -1407,7 +1404,7 @@ def enter_ltf(ltf, time=datetime.datetime.now(), notes=None, force=False): try: prop_list.append([comp, property_type.get(name=p), c[i][p]]) except pw.DoesNotExist: - raise DoesNotExist('Property type "%s" does not exist.' % p) + raise DoesNotExist(f'Property type "{p}" does not exist.') make_connexion(conn_list, time, False, notes, force) sever_connexion(sever, time, notes, force) for p in prop_list: @@ -1437,7 +1434,12 @@ def get_global_flag_times(flag): else: query_ = global_flag.select().where(global_flag.id == flag) - flag_ = query_.join(graph_obj).join(event).where(event.active == True).get() + flag_ = ( + query_.join(graph_obj) + .join(event) + .where(event.active == True) # noqa: E712 + .get() + ) event_ = event.get(graph_obj=flag_.id, active=True) @@ -1472,7 +1474,7 @@ def global_flags_between(start_time, end_time, severity=None): query = global_flag.select() if severity: query = query.where(global_flag.severity == severity) - query = query.join(graph_obj).join(event).where(event.active == True) + query = query.join(graph_obj).join(event).where(event.active == True) # noqa: E712 # Set aliases for the join ststamp = timestamp.alias() diff --git a/ch_util/ni_utils.py b/ch_util/ni_utils.py index 6f3c498c..77da5eed 100644 --- a/ch_util/ni_utils.py +++ b/ch_util/ni_utils.py @@ -1,12 +1,10 @@ """Tools for noise injection data""" import numpy as np -import os import datetime from numpy import linalg as LA from scipy import linalg as sciLA import warnings -import copy from caput import memh5 from caput import mpiarray @@ -222,10 +220,10 @@ def process_synced_data(data, ni_params=None, only_off=False): # Add noise source dataset gate_dset = newdata.create_dataset( - "gated_vis{0}".format(i + 1), data=vis_noise, distributed=dist + f"gated_vis{i + 1}", data=vis_noise, distributed=dist ) gate_dset.attrs["axis"] = np.array( - ["freq", "prod", "gated_time{0}".format(i + 1)] + ["freq", "prod", f"gated_time{i + 1}"] ) gate_dset.attrs["folding_period"] = folding_period gate_dset.attrs["folding_start"] = folding_start @@ -417,10 +415,12 @@ def _find_ni_params(data, verbose=0): if verbose: for i in range(N_ni_sources): - print("\nPWM signal from board %s is enabled" % ni_board[i]) - print("Period: %i GPU integrations" % ni_period) - print("High time: %i GPU integrations" % ni_high_time[i]) - print("FPGA offset: %i GPU integrations\n" % ni_offset[i]) + print("") + print(f"PWM signal from board {ni_board[i]} is enabled") + print(f"Period: {ni_period} GPU integrations") + print(f"High time: {ni_high_time[i]} 
GPU integrations") + print(f"FPGA offset: {ni_offset[i]} GPU integrations") + print("") # Number of fpga frames within a GPU integration int_period = data.attrs["gpu.gpu_intergration_period"][0] @@ -440,9 +440,7 @@ def _find_ni_params(data, verbose=0): for i in range(N_ni_sources) ] - ni_params = {"ni_period": ni_period, "ni_on_bins": ni_on_bins} - - return ni_params + return {"ni_period": ni_period, "ni_on_bins": ni_on_bins} def process_gated_data(data, only_off=False): @@ -544,10 +542,10 @@ def process_gated_data(data, only_off=False): ) memh5.copyattrs(data["gated_vis1"].attrs, gate_dset.attrs) - # The CHIME pipeline uses gpu.gpu_intergration_period to estimate the integration period - # for both the on and off gates. That number has to be changed (divided by 2) since - # with fast gating one integration period has 1/2 of data for the on gate and 1/2 - # for the off gate + # The CHIME pipeline uses gpu.gpu_intergration_period to estimate the + # integration period # for both the on and off gates. That number has to be + # changed (divided by 2) since with fast gating one integration period has + # 1/2 of data for the on gate and 1/2 for the off gate newdata.attrs["gpu.gpu_intergration_period"] = ( data.attrs["gpu.gpu_intergration_period"] // 2 ) @@ -555,7 +553,7 @@ def process_gated_data(data, only_off=False): return newdata -class ni_data(object): +class ni_data: """Provides analysis utilities for CHIME noise injection data. This is just a wrapper for all the utilities created in this module. @@ -592,12 +590,12 @@ def __init__(self, Reader_read_obj, Nadc_channels, adc_ch_ref=None, fbin_ref=Non self.Nadc_channels = Nadc_channels self.raw_vis = Reader_read_obj.vis self.Nfreqs = np.size(self.raw_vis, 0) # Number of frequencies - if adc_ch_ref != None: + if adc_ch_ref is not None: self.adc_ch_ref = adc_ch_ref else: self.adc_ch_ref = self.adc_channels[0] # Default reference channel - if fbin_ref != None: + if fbin_ref is not None: self.fbin_ref = fbin_ref else: # Default reference frequency bin (rather arbitrary) self.fbin_ref = self.Nfreqs // 3 @@ -650,7 +648,7 @@ def get_ni_gains(self, normalize_vis=False, masked_channels=None): """ self.channels = np.arange(self.Nadc_channels) - if masked_channels != None: + if masked_channels is not None: self.channels = np.delete(self.channels, masked_channels) self.Nchannels = len(self.channels) @@ -753,8 +751,7 @@ def utvec2mat(n, utvec): iu = np.triu_indices(n) A = np.zeros((n, n), dtype=np.complex128) A[iu] = utvec # Filling uppper triangle of A - A = A + np.triu(A, 1).conj().T # Filling lower triangle of A - return A + return A + np.triu(A, 1).conj().T # Filling lower triangle of A def ktrprod(A, B): @@ -1035,7 +1032,7 @@ def ni_gains_evalues_tf( Ntimeframes = np.size(vis_gated, 2) # Create NaN matrices to hold the gains and eigenvalues - gains = np.zeros((Nfreqs, Nchannels, Ntimeframes), dtype=np.complex) * ( + gains = np.zeros((Nfreqs, Nchannels, Ntimeframes), dtype=complex) * ( np.nan + 1j * np.nan ) evals = np.zeros((Nfreqs, Nchannels, Ntimeframes), dtype=np.float64) * np.nan @@ -1190,7 +1187,8 @@ def subtract_sky_noise(vis, Nchannels, timestamp, adc_ch_ref, fbin_ref): # Visibilities with noise off for cycle i vis_off_cycle_i = vis[:, :, index_end_on_cycle[i] + 1 : index_start_on_cycle[i]] - # New lines to find indices of maximum and minimum point of each cycle based on the reference channel + # New lines to find indices of maximum and minimum point of each cycle + # based on the reference channel index_max_i = auto_ref[ 
index_start_on_cycle[i] : index_end_on_cycle[i + 1] + 1 ].argmax() @@ -1201,8 +1199,10 @@ def subtract_sky_noise(vis, Nchannels, timestamp, adc_ch_ref, fbin_ref): vis_off_dec.append(vis_off_cycle_i[:, :, index_min_i]) # Instead of averaging all the data with noise on of a cycle, we take the median - # vis_on_dec.append(np.median(vis_on_cycle_i.real, axis=2)+1j*np.median(vis_on_cycle_i.imag, axis=2)) - # vis_off_dec.append(np.median(vis_off_cycle_i.real, axis=2)+1j*np.median(vis_off_cycle_i.imag, axis=2)) + # vis_on_dec.append(np.median(vis_on_cycle_i.real, axis=2)+1j + # * np.median(vis_on_cycle_i.imag, axis=2)) + # vis_off_dec.append(np.median(vis_off_cycle_i.real, axis=2)+1jr + # * np.median(vis_off_cycle_i.imag, axis=2)) timestamp_on_dec.append( np.mean(timestamp[index_start_on_cycle[i] : index_end_on_cycle[i + 1] + 1]) ) @@ -1267,7 +1267,7 @@ def gains2utvec_tf(gains): >>> from ch_util import andata >>> from ch_util import import ni_utils as ni - >>> data = andata.Reader('/scratch/k/krs/jrs65/chime_archive/20140916T173334Z_blanchard_corr/000[0-3]*.h5') + >>> data = andata.Reader('/data/20140916T173334Z_blanchard_corr/000[0-3]*.h5') >>> readdata = data.read() >>> nidata = ni.ni_data(readdata, 16) >>> nidata.get_ni_gains() @@ -1283,7 +1283,7 @@ def gains2utvec_tf(gains): Ntimeframes = np.size(gains, 2) # Number of time frames Nchannels = np.size(gains, 1) Ncorrprods = Nchannels * (Nchannels + 1) // 2 # Number of correlation products - G_ut = np.zeros((Nfreqs, Ncorrprods, Ntimeframes), dtype=np.complex) + G_ut = np.zeros((Nfreqs, Ncorrprods, Ntimeframes), dtype=complex) for f in range(Nfreqs): for t in range(Ntimeframes): diff --git a/ch_util/plot.py b/ch_util/plot.py index 2b9ee4f4..d8b9c872 100644 --- a/ch_util/plot.py +++ b/ch_util/plot.py @@ -1,7 +1,6 @@ """Plotting routines for CHIME data""" import numpy as np -import scipy as sp import matplotlib.pyplot as plt import warnings import datetime @@ -95,13 +94,13 @@ def waterfall( ax = fig.add_subplot(111) # Set title, if given: - if title != None: + if title is not None: ax.set_title(title) # Setting labels, if given: - if x_label != None: + if x_label is not None: ax.set_xlabel(x_label) - if y_label != None: + if y_label is not None: ax.set_ylabel(y_label) # Preparing data shape for plotting: @@ -112,7 +111,7 @@ def waterfall( tmstp = _select_time(data, time_sel) # Apply median filter, if 'med_filt' is given: - if med_filt != None: + if med_filt is not None: msg = "Warning: Wrong value for 'med_filt'. Ignoring argument." 
if med_filt[0] == "new": # Apply median filter: @@ -123,7 +122,7 @@ def waterfall( # Save baseline to file, if given: fileBaseOut = open(med_filt[4], "w") for ii in range(len(baseline)): - fileBaseOut.write("{0}\n".format(baseline[ii, 0])) + fileBaseOut.write(baseline[ii, 0] + "\n") fileBaseOut.close() elif med_filt[0] == "old": # Reshape baseline to ensure type and shape: @@ -134,7 +133,7 @@ def waterfall( print(msg) # Shape data to full day, if 'full_day' is given: - if full_day != None: + if full_day is not None: plt_data = _full_day_shape( plt_data, tmstp, @@ -148,7 +147,7 @@ def waterfall( wtfl = ax.imshow(plt_data[::-1, :], **kwargs) # Ajust aspect ratio of image if aspect is provided: - if aspect != None: + if aspect is not None: _force_aspect(ax, aspect) # Ajust colorbar size: @@ -161,22 +160,22 @@ def waterfall( cbar = fig.colorbar(wtfl) # Set label to colorbar, if given: - if cbar_label != None: + if cbar_label is not None: cbar.set_label(cbar_label) # Output depends on keyword arguments: - if show_plot == True: + if show_plot is True: plt.show() - elif (show_plot != None) and (show_plot != False): + elif (show_plot is not None) and (show_plot is not False): msg = ( 'Optional keyword argument "show_plot" should receive either' - ' "True" or "False". Received "{0}". Ignoring argument.'.format(show_plot) + f' "True" or "False". Received "{show_plot}". Ignoring argument.' ) warnings.warn(msg, SyntaxWarning) # Save to file if filename is provided: - if out_file != None: - if res != None: + if out_file is not None: + if res is not None: fig.savefig(out_file, dpi=res) else: fig.savefig(out_file) @@ -326,7 +325,7 @@ def _coerce_data_shape( # Selects what part to plot: # Defaults to plotting real part of data. - if part_sel == "real" or part_sel == None: + if part_sel == "real" or part_sel is None: data = data.real elif part_sel == "imag": data = data.imag @@ -340,7 +339,7 @@ def _coerce_data_shape( msg = ( 'Optional keyword argument "part_sel" has to receive' ' one of "real", "imag", "mag", "phase" or "complex".' - ' Received "{0}"'.format(part_sel) + f' Received "{part_sel}"' ) raise ValueError(msg) @@ -464,7 +463,7 @@ def _full_day_shape(data, tmstp, date, n_bins=8640, axis="solar", ax=None): break # Set azimuth ticks, if given: - if ax != None: + if ax is not None: tck_stp = n_bins / 6.0 ticks = np.array( [ @@ -506,7 +505,7 @@ def _full_day_shape(data, tmstp, date, n_bins=8640, axis="solar", ax=None): break # Set time ticks, if given: - if ax != None: + if ax is not None: tck_stp = n_bins / 6.0 ticks = np.array( [ @@ -523,7 +522,7 @@ def _full_day_shape(data, tmstp, date, n_bins=8640, axis="solar", ax=None): # Set label: ax.set_xlabel("Time (UTC-8 hours)") - print("Number of 10-second bins added to full day data: {0}".format(n_added)) + print(f"Number of 10-second bins added to full day data: {n_added}") # Set new array to NaN for subsequent masking: Z = np.ones((1024, n_bins)) @@ -534,7 +533,7 @@ def _full_day_shape(data, tmstp, date, n_bins=8640, axis="solar", ax=None): for ii in range(n_bins): n_col = len(values_to_sum[ii]) if n_col > 0: - col = np.zeros((1024)) + col = np.zeros(1024) for jj in range(n_col): col = col + data[:, values_to_sum[ii][jj]] Z[:, ii] = col / float(n_col) @@ -577,27 +576,29 @@ def _force_aspect(ax, aspect=1.0): def _med_filter(data, n_bins=200, i_bin=0, filt_window=37): - """Normalize a 2D array by its power spectrum averaged over 'n_bins' starting at 'i_bin'. + """Normalize a 2D array. 
+ + The array is normalized by its power spectrum averaged over 'n_bins' + starting at 'i_bin'. Parameters ---------- data : numpy.ndarray - Data to be normalized - + Data to be normalized n_bins : integer - Number of bins over which to average the power spectrum - + Number of bins over which to average the power spectrum i_bin : integer - First bin of the range over which to average the power spectrum - + First bin of the range over which to average the power spectrum filt_window : integer - Width of the window for the median filter. The filter is applied - once with this window and a second time with 1/3 of this window width. + Width of the window for the median filter. The filter is applied + once with this window and a second time with 1/3 of this window width. Returns ------- - rel_power : 2d array normalized by average power spectrum (baseline) - medfilt_baseline : Average power spectrum + rel_power + 2d array normalized by average power spectrum (baseline) + medfilt_baseline + Average power spectrum Issues ------ diff --git a/ch_util/rfi.py b/ch_util/rfi.py index b5215c25..afb11151 100644 --- a/ch_util/rfi.py +++ b/ch_util/rfi.py @@ -27,7 +27,7 @@ import warnings import logging -from typing import Tuple, Optional, Union +from typing import Optional, Union import numpy as np import scipy.signal as sig @@ -41,7 +41,8 @@ logger.addHandler(logging.NullHandler()) -# Ranges of bad frequencies given by their start time (in unix time) and corresponding start and end frequencies (in MHz) +# Ranges of bad frequencies given by their start time (in unix time) and +# corresponding start and end frequencies (in MHz) # If the start time is not specified, t = [], the flag is applied to all CSDs BAD_FREQUENCIES = { "chime": [ @@ -309,7 +310,7 @@ def number_deviations( def get_autocorrelations( data, stack: bool = False, normalize: bool = False -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Extract autocorrelations from a data stack. Parameters @@ -715,7 +716,8 @@ def mad_cut_rolling( mfd = tools.invert_no_zero(nanmedian(data, axis=1)) data *= mfd[:, np.newaxis] - # Add NaNs around the edges of the array so that we don't have to treat them separately + # Add NaNs around the edges of the array so that we don't have to + # treat them separately eshp = [nfreq + fwidth - 1, ntime + twidth - 1] exp_data = np.full(eshp, np.nan, dtype=data.dtype) exp_data[foff : foff + nfreq, toff : toff + ntime] = data @@ -792,9 +794,7 @@ def highpass_delay_filter(freq, tau_cut, flag, epsilon=1e-10): np.sinc(2.0 * tau_cut * (freq[:, np.newaxis] - freq[np.newaxis, :])) / epsilon ) - pinv = np.linalg.pinv(cov * mflag, hermitian=True) * mflag - - return pinv + return np.linalg.pinv(cov * mflag, hermitian=True) * mflag def iterative_hpf_masking( diff --git a/ch_util/timing.py b/ch_util/timing.py index 493f394f..f6cc0af1 100644 --- a/ch_util/timing.py +++ b/ch_util/timing.py @@ -134,7 +134,8 @@ def from_dict(self, **kwargs): The coefficient of the spectral model of the amplitude variations of each of the noise source inputs versus time. weight_alpha: np.ndarray[nsource, ntime] - Estimate of the uncertainty (inverse variance) on the amplitude coefficients. + Estimate of the uncertainty (inverse variance) on the amplitude + coefficients. static_amp: np.ndarray[nfreq, nsource] The amplitude that was subtracted from each frequency and input prior to fitting for the amplitude variations. 
This is necessary to remove the @@ -143,7 +144,8 @@ def from_dict(self, **kwargs): Inverse variance on static_amp. num_freq: np.ndarray[nsource, ntime] The number of frequencies used to determine the delay and alpha quantities. - If num_freq is 0, then that time is ignored when deriving the timing correction. + If num_freq is 0, then that time is ignored when deriving the timing + correction. coeff_tau: np.ndarray[ninput, nsource] If coeff is provided, then the timing correction applied to a particular input will be the linear combination of the tau correction from the @@ -175,7 +177,7 @@ def from_dict(self, **kwargs): else: dset = tcorr.create_dataset(name, data=data) - dset.attrs["axis"] = np.array(spec["axis"], dtype=np.string_) + dset.attrs["axis"] = np.array(spec["axis"], dtype=np.bytes_) return tcorr @@ -243,12 +245,10 @@ def _interpret_and_read(cls, acq_files, start, stop, datasets, out_group): # Now concatenate the files. Dynamic datasets will be concatenated. # Static datasets will be extracted from the first file. - data = tod.concatenate( + return tod.concatenate( objs, out_group=out_group, start=start, stop=stop, datasets=datasets ) - return data - @property def freq(self): """Provide convenience access to the frequency bin centres.""" @@ -297,7 +297,7 @@ def weight_tau(self): dset = self.create_flag("weight_tau", data=weight_tau) dset.attrs["axis"] = np.array( - DSET_SPEC["weight_tau"]["axis"], dtype=np.string_ + DSET_SPEC["weight_tau"]["axis"], dtype=np.bytes_ ) return self.flags["weight_tau"] @@ -339,7 +339,7 @@ def weight_alpha(self): dset = self.create_flag("weight_alpha", data=weight_alpha) dset.attrs["axis"] = np.array( - DSET_SPEC["weight_alpha"]["axis"], dtype=np.string_ + DSET_SPEC["weight_alpha"]["axis"], dtype=np.bytes_ ) return self.flags["weight_alpha"] @@ -394,29 +394,34 @@ def has_coeff_alpha(self): @property def amp_to_delay(self): - """Return conversion from noise source amplitude variations to delay variations.""" + """This is the conversion from noise source amplitude variations + to delay variations.""" return self.attrs.get("amp_to_delay", None) @amp_to_delay.setter def amp_to_delay(self, val): - """Sets the conversion from noise source amplitude variations to delay variations. + """Set amp_to_delay + + Note that setting this quantity will result in the following + modification to the timing correction: - Note that setting this quantity will result in the following modification to the - timing correction: tau --> tau - amp_to_delay * alpha. This can be used to remove - variations introduced by the noise source distribution system from the timing correction - using the amplitude variations as a proxy for temperature. + tau --> tau - amp_to_delay * alpha + + This can be used to remove variations introduced by the noise + source distribution system from the timing correction using the + amplitude variations as a proxy for temperature. """ if self.has_coeff_alpha: raise AttributeError( "The amplitude variations are already being used to " "correct the delay variations through the coeff_alpha dataset." 
) - elif val is not None: + if val is not None: self.attrs["amp_to_delay"] = val @amp_to_delay.deleter def amp_to_delay(self): - """Remove any conversion from noise source amplitude variations to delay variations.""" + """Remove amp_to_delay""" if "amp_to_delay" in self.attrs: del self.attrs["amp_to_delay"] @@ -431,8 +436,8 @@ def reference_noise_source(self): if "reference_noise_source" in self.datasets: iref = self.datasets["reference_noise_source"][:] return iref if np.unique(iref).size > 1 else iref[0] - else: - return self.zero_delay_noise_source + + return self.zero_delay_noise_source @property def zero_delay_noise_source(self): @@ -443,8 +448,8 @@ def zero_delay_noise_source(self): "Could not determine which input the delay template " "is referenced with respect to." ) - else: - return zero_tau[0] + + return zero_tau[0] def set_coeff( self, @@ -505,7 +510,7 @@ def set_coeff( dset = self.create_flag(name, data=coeff[:, reod]) else: dset = self.create_dataset(name, data=coeff[:, reod]) - dset.attrs["axis"] = np.array(spec["axis"], dtype=np.string_) + dset.attrs["axis"] = np.array(spec["axis"], dtype=np.bytes_) if reference_noise_source is not None: ref_sn_lookup = { @@ -525,7 +530,7 @@ def set_coeff( dset = self.create_flag(name, data=reference_reodered) else: dset = self.create_dataset(name, data=reference_reodered) - dset.attrs["axis"] = np.array(spec["axis"], dtype=np.string_) + dset.attrs["axis"] = np.array(spec["axis"], dtype=np.bytes_) self.create_index_map("input", inputs) @@ -573,23 +578,25 @@ def set_global_reference_time(self, tref, window=0.0, interpolate=False, **kwarg tref : unix time Reference the templates to the values at this time. window: float - Reference the templates to the median value over a window (in seconds) - around tref. If nonzero, this will override the interpolate keyword. + Reference the templates to the median value over a window (in + seconds) around tref. If nonzero, this will override the + interpolate keyword. interpolate : bool - Interpolate the delay template to time tref. Otherwise take the measured time - nearest to tref. The get_tau method is use to perform the interpolation, and - kwargs for that method will be passed along. + Interpolate the delay template to time tref. Otherwise take + the measured time nearest to tref. The get_tau method is use + to perform the interpolation, and kwargs for that method will + be passed along. """ tref = ctime.ensure_unix(tref) tref_string = ctime.unix_to_datetime(tref).strftime("%Y-%m-%d %H:%M:%S %Z") - logger.info("Referencing timing correction with respect to %s." % tref_string) + logger.info(f"Referencing timing correction with respect to {tref_string}.") if window > 0.0: iref = np.flatnonzero( (self.time >= (tref - window)) & (self.time <= (tref + window)) ) if iref.size > 0: logger.info( - "Using median of %d samples around reference time." % iref.size + f"Using median of {iref.size} samples around reference time." ) if self.has_num_freq: tau_ref = np.zeros((self.nsource, 1), dtype=self.tau.dtype) @@ -607,13 +614,11 @@ def set_global_reference_time(self, tref, window=0.0, interpolate=False, **kwarg else: raise ValueError( - "Timing correction not available for time %s." % tref_string + f"Timing correction not available for time {tref_string}." ) elif (tref < self.time[0]) or (tref > self.time[-1]): - raise ValueError( - "Timing correction not available for time %s." 
% tref_string - ) + raise ValueError(f"Timing correction not available for time {tref_string}.") else: if not interpolate: @@ -770,7 +775,7 @@ def get_tau(self, timestamp, ignore_amp=False, interp="linear", extrap_limit=Non else: logger.info( "Correcting delay template using amplitude template " - "with coefficient %0.1f." % self.amp_to_delay + f"with coefficient {self.amp_to_delay:.1f}." ) # Determine which input the delay template is referenced to @@ -823,7 +828,8 @@ def get_alpha(self, timestamp, interp="linear", extrap_limit=None): alpha: np.ndarray[nsource, ntime] Amplitude coefficient as a function of time for each of the noise sources. weight : np.ndarray[nsource, ntime] - The uncertainty on the amplitude coefficient, expressed as an inverse variance. + The uncertainty on the amplitude coefficient, expressed as an + inverse variance. """ flag = self.num_freq[:] > 0 if self.has_num_freq else None @@ -843,12 +849,13 @@ def get_alpha(self, timestamp, interp="linear", extrap_limit=None): def get_stacked_tau( self, timestamp, inputs, prod, reverse_stack, input_flags=None, **kwargs ): - """Return the appropriate delay for each stacked visibility at the requested time. + """Return the delay for each stacked visibility at the requested time. - Averages the delays from the noise source inputs that map to the set of redundant - baseline included in each stacked visibility. This yields the appropriate - common-mode delay correction. If input_flags is provided, then the bad inputs - that were excluded from the stack are also excluded from the delay template averaging. + Averages the delays from the noise source inputs that map to the + set of redundant baseline included in each stacked visibility. + This yields the appropriate common-mode delay correction. If + input_flags is provided, then the bad inputs that were excluded + from the stack are also excluded from the delay template averaging. Parameters ---------- @@ -910,16 +917,18 @@ def get_stacked_tau( def get_stacked_alpha( self, timestamp, inputs, prod, reverse_stack, input_flags=None, **kwargs ): - """Return the equivalent of `get_stacked_tau` for the noise source amplitude variations. - - Averages the alphas from the noise source inputs that map to the set of redundant - baseline included in each stacked visibility. If input_flags is provided, then the - bad inputs that were excluded from the stack are also excluded from the alpha - template averaging. This method can be used to generate a stacked alpha template - that can be used to correct a stacked tau template for variations in the noise source - distribution system. However, it is recommended that the tau template be corrected - before stacking. This is accomplished by setting the `amp_to_delay` property - prior to calling `get_stacked_tau`. + """Return the stacked alphas for the noise source amplitude variations. + + Averages the alphas from the noise source inputs that map to the + set of redundant baseline included in each stacked visibility. If + input_flags is provided, then the bad inputs that were excluded + from the stack are also excluded from the alpha template averaging. + This method can be used to generate a stacked alpha template that + can be used to correct a stacked tau template for variations in the + noise source distribution system. However, it is recommended that + the tau template be corrected before stacking. This is accomplished + by setting the `amp_to_delay` property prior to calling + `get_stacked_tau`. 
Parameters ---------- @@ -942,7 +951,8 @@ def get_stacked_alpha( Returns ------- alpha: np.ndarray[nstack, ntime] - Noise source amplitude variation as a function of time for each stacked visibility. + Noise source amplitude variation as a function of time for each + stacked visibility. """ if not self.has_amplitude: raise AttributeError( @@ -1077,11 +1087,14 @@ def _stack( return stacked_data def get_timing_correction(self, freq, timestamp, **kwargs): - """Return the phase correction from each noise source at the requested frequency and time. + """Return the phase correction from each noise source. + + Assumes the phase correction scales with frequency nu as + + phi = 2 pi nu tau - Assumes the phase correction scales with frequency nu as phi = 2 pi nu tau and uses the - get_tau method to interpolate over time. It acccepts and passes along keyword arguments - for that method. + and uses the get_tau method to interpolate over time. It acccepts and + passes along keyword arguments for that method. Parameters ---------- @@ -1093,9 +1106,11 @@ def get_timing_correction(self, freq, timestamp, **kwargs): Returns ------- gain: np.ndarray[nfreq, nsource, ntime] - Complex gain containing a pure phase correction for each of the noise sources. + Complex gain containing a pure phase correction for each of the + noise sources. weight: np.ndarray[nfreq, nsource, ntime] - Uncerainty on the gain for each of the noise sources, expressed as an inverse variance. + Uncerainty on the gain for each of the noise sources, + expressed as an inverse variance. """ tau, wtau = self.get_tau(timestamp, **kwargs) @@ -1230,14 +1245,16 @@ def apply_timing_correction(self, timestream, copy=False, **kwargs): Parameters ---------- timestream : andata.CorrData / equivalent or np.ndarray[nfreq, nprod, ntime] - If timestream is an np.ndarray containing the visiblities, then you - must also pass the corresponding freq, prod, input, and time axis as kwargs. - Otherwise these quantities are obtained from the attributes of CorrData. - If the visibilities have been stacked, then you must additionally pass the - stack and reverse_stack axis as kwargs, and (optionally) the input flags. + If timestream is an np.ndarray containing the visiblities, the + you must also pass the corresponding freq, prod, input, and + time axis as kwargs. Otherwise these quantities are obtained + from the attributes of CorrData. If the visibilities have been + stacked, then you must additionally pass the stack and + reverse_stack axis as kwargs, and (optionally) the input flags. copy : bool - Create a copy of the input visibilities. Apply the timing correction to - the copy and return it, leaving the original untouched. Default is False. + Create a copy of the input visibilities. Apply the timing + correction to the copy and return it, leaving the original + untouched. Default is False. freq : np.ndarray[nfreq, ] Frequency in MHz. Must be passed as keyword argument if timestream is an np.ndarray. @@ -1260,8 +1277,9 @@ def apply_timing_correction(self, timestream, copy=False, **kwargs): Must be passed as keyword argument if timestream is an np.ndarray and the visibilities have been stacked. input_flags : np.ndarray [ninput, ntime] - Array indicating which inputs were good at each time. Non-zero value - indicates that an input was good. Optional. Only used for stacked visibilities. + Array indicating which inputs were good at each time. Non-zero + value indicates that an input was good. Optional. Only used for + stacked visibilities. 
Returns ------- @@ -1271,8 +1289,9 @@ def apply_timing_correction(self, timestream, copy=False, **kwargs): else: None Correction is applied to the input visibility data. Also, - if timestream is an andata.CorrData instance and the gain dataset exists, - then it will be updated with the complex gains that have been applied. + if timestream is an andata.CorrData instance and the gain + dataset exists, then it will be updated with the complex + gains that have been applied. """ if isinstance(timestream, np.ndarray): is_obj = False @@ -1367,8 +1386,7 @@ def apply_timing_correction(self, timestream, copy=False, **kwargs): # If a copy was requested, then return the # new vis with phase correction applied - if copy: - return vis + return vis if copy else None def summary(self): """Provide a summary of the timing correction. @@ -1394,7 +1412,7 @@ def summary(self): fmt = "%-10s %10s %10s %15s %15s" hdr = fmt % ("", "PHI0", "TAU0", "SIGMA(TAU)", "SIGMA(TAU)") - per = fmt % ("", "", "", "@ %0.2f sec" % step, "@ %0.2f hr" % span) + per = fmt % ("", "", "", f"@ {step:.2f} sec", f"@ {span:.2f} hr") unt = fmt % ("INPUT", "[rad]", "[nsec]", "[psec]", "[psec]") line = "".join(["-"] * 65) summary = [line, hdr, per, unt, line] @@ -1470,7 +1488,7 @@ def from_acq_h5(cls, acq_files, only_correction=False, **kwargs): ) # Load the data into an andata.CorrData object - corr_data = super(TimingData, cls).from_acq_h5( + corr_data = super().from_acq_h5( acq_files, apply_gain=apply_gain, datasets=datasets, **kwargs ) @@ -1536,7 +1554,7 @@ def from_acq_h5(cls, acq_files, only_correction=False, **kwargs): # Create index map containing names of parameters param = ["intercept", "slope", "quad", "cube", "quart", "quint"] param = param[0 : res["static_phi_fit"].shape[0]] - data.create_index_map("param", np.array(param, dtype=np.string_)) + data.create_index_map("param", np.array(param, dtype=np.bytes_)) # Create datasets containing the timing correction for name, arr in res.items(): @@ -1546,7 +1564,7 @@ def from_acq_h5(cls, acq_files, only_correction=False, **kwargs): else: dset = data.create_dataset(name, data=arr) - dset.attrs["axis"] = np.array(spec["axis"], dtype=np.string_) + dset.attrs["axis"] = np.array(spec["axis"], dtype=np.bytes_) # Delete the temporary corr_data object del corr_data @@ -1567,7 +1585,7 @@ def summary(self): presented as quantiles over frequency for each of the noise source products. """ - summary = super(TimingData, self).summary() + summary = super().summary() vis = self.apply_timing_correction( self.vis[:], @@ -1623,7 +1641,7 @@ def summary(self): return summary -class TimingInterpolator(object): +class TimingInterpolator: """Interpolation that is aware of flagged data and weights. Flagged data is ignored during the interpolation. The weights from @@ -1728,7 +1746,7 @@ def __call__(self, xeval): def load_timing_correction( files, start=None, stop=None, window=43200.0, instrument="chime", **kwargs ): - """Find and load the appropriate timing correction for a list of corr acquisition files. + """Find and load the timing correction for a list of corr acquisition files. For example, if the instrument keyword is set to 'chime', then this function will accept all types of chime corr acquisition files, @@ -1773,7 +1791,7 @@ def load_timing_correction( if not acq_inst.startswith(instrument) or (acq_type != "corr"): raise RuntimeError( "This function is only able to parse corr type files " - "from the specified instrument (currently %s)." 
% instrument + f"from the specified instrument (currently {instrument})." ) # Search for all timing acquisitions on this node @@ -1781,9 +1799,10 @@ def load_timing_correction( glob.glob(os.path.join(node, "_".join(["*", instrument + "timing", acq_type]))) ) if not tdirs: - raise RuntimeError("No timing acquisitions found on node %s." % node) + raise RuntimeError(f"No timing acquisitions found on node {node}.") - # Determine the start time of the requested acquistion and the available timing acquisitions + # Determine the start time of the requested acquistion and the available + # timing acquisitions acq_start = ctime.datetime_to_unix(ctime.timestr_to_datetime(acq_date)) tacq_start = np.array( @@ -1794,7 +1813,7 @@ def load_timing_correction( # Find the closest timing acquisition to the requested acquisition iclose = np.argmin(np.abs(acq_start - tacq_start)) if np.abs(acq_start - tacq_start[iclose]) > 60.0: - raise RuntimeError("Cannot find appropriate timing acquisition for %s." % acq) + raise RuntimeError(f"Cannot find appropriate timing acquisition for {acq}.") # Grab all timing files from this acquisition tfiles = sorted(glob.glob(os.path.join(tdirs[iclose], "*.h5"))) @@ -1811,9 +1830,7 @@ def load_timing_correction( tstop = int(np.argmin(np.abs(time_stop - tdata.time))) # Load into TimingData object - data = TimingData.from_acq_h5(tfiles, start=tstart, stop=tstop, **kwargs) - - return data + return TimingData.from_acq_h5(tfiles, start=tstart, stop=tstop, **kwargs) # ancillary functions @@ -1909,22 +1926,22 @@ def construct_delay_template( time averaged phase versus frequency. Default is 2. static_phi: np.ndarray[nfreq, nsource] Subtract this quantity from the noise source phase prior to fitting - for the timing correction. If None, then this will be estimated from the median - of the noise source phase over time. + for the timing correction. If None, then this will be estimated from + the median of the noise source phase over time. weight_static_phi: np.ndarray[nfreq, nsource] - Inverse variance of the time averaged phased. Set to zero for frequencies and inputs - that are missing or should be ignored. If None, then this will be estimated from the - residuals of the fit. + Inverse variance of the time averaged phased. Set to zero for + frequencies and inputs that are missing or should be ignored. + If None, then this will be estimated from the residuals of the fit. static_phi_fit: np.ndarray[nparam, nsource] Polynomial fit to static_phi versus frequency. static_amp: np.ndarray[nfreq, nsource] Subtract this quantity from the noise source amplitude prior to fitting - for the amplitude variations. If None, then this will be estimated from the median - of the noise source amplitude over time. + for the amplitude variations. If None, then this will be estimated from + the median of the noise source amplitude over time. weight_static_amp: np.ndarray[nfreq, nsource] - Inverse variance of the time averaged amplitude. Set to zero for frequencies and inputs - that are missing or should be ignored. If None, then this will be estimated from the - residuals of the fit. + Inverse variance of the time averaged amplitude. Set to zero for + frequencies and inputs that are missing or should be ignored. If None, + then this will be estimated from the residuals of the fit. 
Returns ------- @@ -2042,7 +2059,7 @@ def construct_delay_template( # If requested discard frequencies and times that have high frac_lost if hasattr(data, "flags") and ("frac_lost" in data.flags): - logger.info("Fraction of data kept must be greater than %0.2f." % min_frac_kept) + logger.info(f"Fraction of data kept must be greater than {min_frac_kept:.2f}.") frac_kept = 1.0 - data.flags["frac_lost"][:].view(np.ndarray) flg &= frac_kept[:, np.newaxis, :] >= min_frac_kept @@ -2104,7 +2121,8 @@ def construct_delay_template( phi = phi[:].view(np.ndarray) ww = ww[:].view(np.ndarray) - # If a frequency is flagged more than `threshold` fraction of the time, then flag it entirely + # If a frequency is flagged more than `threshold` fraction of the time, + # then flag it entirely ww *= ( ( np.sum(ww > 0.0, axis=-1, dtype=np.float32, keepdims=True) @@ -2113,13 +2131,14 @@ def construct_delay_template( > threshold ).astype(np.float32) + fraction = ( + 100.0 + * np.sum(np.any(ww > 0.0, axis=(1, 2)), dtype=np.float32) + / float(ww.shape[0]) + ) logger.info( - "%0.1f percent of frequencies will be used to construct timing correction." - % ( - 100.0 - * np.sum(np.any(ww > 0.0, axis=(1, 2)), dtype=np.float32) - / float(ww.shape[0]), - ) + f"{fraction:.1f} percent of frequencies will be used " + "to construct timing correction." ) # If the starting values for the mean and variance were not provided, @@ -2208,12 +2227,12 @@ def construct_delay_template( if check_amp: nsigma = np.abs(ramp) * np.sqrt(weight_static_amp[:, :, np.newaxis]) not_outlier *= (nsigma < nsigma_amp[iter_weight]).astype(np.float32) - msg.append("nsigma_amp = %0.1f" % nsigma_amp[iter_weight]) + msg.append(f"nsigma_amp = {nsigma_amp[iter_weight]:.1f}") if check_phi: nsigma = np.abs(rphi) * np.sqrt(weight_static_phi[:, :, np.newaxis]) not_outlier *= (nsigma < nsigma_phi[iter_weight]).astype(np.float32) - msg.append("nsigma_phi = %0.1f" % nsigma_phi[iter_weight]) + msg.append(f"nsigma_phi = {nsigma_phi[iter_weight]:.1f}") if check_amp or check_phi: ww *= not_outlier @@ -2272,22 +2291,22 @@ def construct_delay_template( data.redistribute("freq") # Return results - return dict( - tau=tau, - alpha=alpha, - weight_tau=weight_tau, - weight_alpha=weight_alpha, - static_phi=static_phi, - static_amp=static_amp, - weight_static_phi=weight_static_phi, - weight_static_amp=weight_static_amp, - static_phi_fit=static_phi_fit, - num_freq=num_freq, - phi=phi, - amp=amp, - weight_phi=weight_phi, - weight_amp=weight_amp, - ) + return { + "tau": tau, + "alpha": alpha, + "weight_tau": weight_tau, + "weight_alpha": weight_alpha, + "static_phi": static_phi, + "static_amp": static_amp, + "weight_static_phi": weight_static_phi, + "weight_static_amp": weight_static_amp, + "static_phi_fit": static_phi_fit, + "num_freq": num_freq, + "phi": phi, + "amp": amp, + "weight_phi": weight_phi, + "weight_amp": weight_amp, + } def map_input_to_noise_source(inputs, noise_sources): @@ -2340,13 +2359,11 @@ def count_startswith(x, y): source_names = list(map(parse_serial, noise_sources["correlator_input"])) # Map each input to a noise source - imap = [ + return [ np.argmax([count_startswith(inp, src) for src in source_names]) for inp in input_names ] - return imap - def eigen_decomposition(vis, flag): """Eigenvalue decomposition of the visibility matrix. 
@@ -2459,8 +2476,8 @@ def fit_poly_to_phase(freq, resp, resp_error, nparam=2): _func_poly_phase, x, y, p0=p0.copy(), sigma=err, absolute_sigma=False ) - except Exception as excep: - logger.warning("Nonlinear phase fit failed with error: %s" % excep) + except (ValueError, RuntimeError) as excep: + logger.warning(f"Nonlinear phase fit failed with error: {excep}") # Fit failed, return the initial parameter estimates popt = p0 pcov = np.zeros((nparam, nparam), dtype=np.float64) @@ -2573,14 +2590,12 @@ def _search_nearest(x, xeval): index_previous = np.maximum(0, index_next - 1) index_next = np.minimum(x.size - 1, index_next) - index = np.where( + return np.where( np.abs(xeval - x[index_previous]) < np.abs(xeval - x[index_next]), index_previous, index_next, ) - return index - def _interpolation_nearest(x, y, var, xeval): index = _search_nearest(x, xeval) diff --git a/ch_util/tools.py b/ch_util/tools.py index b79669fd..91077f9c 100644 --- a/ch_util/tools.py +++ b/ch_util/tools.py @@ -52,15 +52,20 @@ Fetch the inputs for blanchard during layout 38:: >>> from datetime import datetime - >>> inputs = get_correlator_inputs(datetime(2016,05,23,00), correlator='pathfinder') + >>> inputs = get_correlator_inputs(datetime(2016,05,23,00), + ... correlator='pathfinder') >>> inputs[1] - CHIMEAntenna(id=1, reflector=u'W_cylinder', antenna=u'ANT0123B', powered=True, pos=9.071800000000001, input_sn=u'K7BP16-00040401', pol=u'S', corr=u'K7BP16-0004', cyl=0) + CHIMEAntenna(id=1, reflector=u'W_cylinder', antenna=u'ANT0123B', + powered=True, pos=9.071800000000001, input_sn=u'K7BP16-00040401', + pol=u'S', corr=u'K7BP16-0004', cyl=0) >>> print "NS position:", inputs[1].pos NS position: 9.0718 >>> print "Polarisation:", inputs[1].pol Polarisation: S >>> inputs[3] - CHIMEAntenna(id=3, reflector=u'W_cylinder', antenna=u'ANT0128B', powered=True, pos=9.681400000000002, input_sn=u'K7BP16-00040403', pol=u'S', corr=u'K7BP16-0004', cyl=0) + CHIMEAntenna(id=3, reflector=u'W_cylinder', antenna=u'ANT0128B', + powered=True, pos=9.681400000000002, input_sn=u'K7BP16-00040403', + pol=u'S', corr=u'K7BP16-0004', cyl=0) Housekeeping Inputs =================== @@ -122,18 +127,23 @@ - :py:meth:`ensure_list` """ +from __future__ import annotations + import datetime import numpy as np import scipy.linalg as la import re -from typing import Tuple from caput import pfb -from caput.interferometry import projected_distance, fringestop_phase +from caput.interferometry import ( + projected_distance, + fringestop_phase as fringestop_phase, +) import ch_ephem.observers -# All telescope geometry (rotation, roll, offsets) moved to ch_ephem/instruments.yaml +# All telescope geometry (rotation, roll, offsets) moved to +# ch_ephem/instruments.yaml # # To access them: # @@ -155,7 +165,7 @@ # ======= -class HKInput(object): +class HKInput: """A housekeeping input. Parameters @@ -197,7 +207,7 @@ def __repr__(self): return ret -class CorrInput(object): +class CorrInput: """Base class for describing a correlator input. Meant to be subclassed by actual types of inputs. 
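# The tools.py hunks above and below all follow the pyupgrade ("UP") rules
# selected in pyproject.toml: the redundant `object` base class is dropped
# (UP004) and `typing.Tuple` annotations become the builtin `tuple` (UP006),
# which the added `from __future__ import annotations` keeps safe on older
# interpreters.  A minimal, self-contained sketch of that pattern; the class
# and method below are illustrative only and are not part of ch_util.
from __future__ import annotations

import numpy as np


class ExampleInput:  # previously written as `class ExampleInput(object):`
    """Toy input description used only to illustrate the rewrite."""

    def __init__(self, sn: str):
        self.sn = sn

    def split_vis(self, vis: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # previously annotated as `-> Tuple[np.ndarray, np.ndarray]`
        return vis.real, vis.imag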
@@ -241,16 +251,14 @@ def _attribute_strings(self): for k in ["id", "crate", "slot", "sma", "corr_order", "delay"] ] - kv = ["%s=%s" % (k, repr(v)) for k, v in prop if v is not None] + [ - "%s=%s" % (k, repr(v)) for k, v in self.__dict__.items() if k[0] != "_" + return [k + "=" + repr(v) for k, v in prop if v is not None] + [ + k + "=" + repr(v) for k, v in self.__dict__.items() if k[0] != "_" ] - return kv - def __repr__(self): kv = self._attribute_strings() - return "%s(%s)" % (self.__class__.__name__, ", ".join(kv)) + return self.__class__.name__ + "(" + ", ".join(kv) + ")" @property def id(self): @@ -264,8 +272,8 @@ def id(self): """ if hasattr(self, "_id"): return self._id - else: - return serial_to_id(self.input_sn) + + return serial_to_id(self.input_sn) @id.setter def id(self, val): @@ -365,10 +373,10 @@ class ArrayAntenna(Antenna): flag = None def _attribute_strings(self): - kv = super(ArrayAntenna, self)._attribute_strings() + kv = super()._attribute_strings() if self.pos is not None: - pos = ", ".join(["%0.2f" % pp for pp in self.pos]) - kv.append("pos=[%s]" % pos) + pos = ", ".join([f"{pp:.2f}" for pp in self.pos]) + kv.append(f"pos=[{pos}]") return kv @property @@ -387,8 +395,7 @@ def pos(self): return pos - else: - return None + return None @pos.setter def pos(self, val): @@ -501,10 +508,8 @@ class HolographyAntenna(Antenna): def _ensure_graph(graph): from . import layout - try: - graph.sg_spec - except: - graph = layout.graph(graph) + if not isinstance(graph, layout.graph): + return layout.graph(graph) return graph @@ -531,21 +536,17 @@ def _get_feed_position(lay, rfl, foc, cas, slt, slot_factor): pos : list x,y,z coordinates of the feed relative to the centre of the focal line. """ - try: - pos = [0.0] * 3 - - for node in [rfl, foc, cas, slt]: - prop = lay.node_property(node) + pos = [0.0] * 3 - for ind, dim in enumerate(["x_offset", "y_offset", "z_offset"]): - if dim in prop: - pos[ind] += float(prop[dim].value) # in metres + for node in [rfl, foc, cas, slt]: + prop = lay.node_property(node) - if "y_offset" not in lay.node_property(slt): - pos[1] += (float(slt.sn[-1]) - slot_factor) * 0.3048 + for ind, dim in enumerate(["x_offset", "y_offset", "z_offset"]): + if dim in prop: + pos[ind] += float(prop[dim].value) # in metres - except: - pos = None + if "y_offset" not in lay.node_property(slt): + pos[1] += (float(slt.sn[-1]) - slot_factor) * 0.3048 return pos @@ -607,7 +608,8 @@ def find(name): if rft is not None: break - # If the antenna does not exist, it might be the RFI antenna, the noise source, or empty + # If the antenna does not exist, it might be the RFI antenna, + # the noise source, or empty if ant is None: if rfi_antenna is not None: rfl = lay.closest_of_type( @@ -632,18 +634,17 @@ def find(name): return Blank(id=chan_id, input_sn=corr_input.sn, corr=corr_sn) # Determine polarization from antenna properties + keydict = { + "H": "hpol_orient", + "V": "vpol_orient", + "1": "pol1_orient", + "2": "pol2_orient", + } + + pkey = keydict[pol.sn[-1]] try: - keydict = { - "H": "hpol_orient", - "V": "vpol_orient", - "1": "pol1_orient", - "2": "pol2_orient", - } - - pkey = keydict[pol.sn[-1]] pdir = lay.node_property(ant)[pkey].value - - except: + except KeyError: pdir = None # Determine serial number of RF thru @@ -713,18 +714,18 @@ def find(name): flag=flag, ) - elif cyl == 0 or cyl == 1: + if cyl == 0 or cyl == 1: # Dealing with a pathfinder feed # Determine y_offset - try: - pos = [0.0] * 3 + pos = [0.0] * 3 - pos[0] = cyl * _PF_SPACE + pos[0] = cyl * _PF_SPACE - cas_prop 
= lay.node_property(cas) - slt_prop = lay.node_property(slt) + cas_prop = lay.node_property(cas) + slt_prop = lay.node_property(slt) + try: d1 = float(cas_prop["dist_to_n_end"].value) / 100.0 # in metres d2 = float(slt_prop["dist_to_edge"].value) / 100.0 # in metres orient = cas_prop["slot_zero_pos"].value @@ -733,8 +734,7 @@ def find(name): # Turn into distance increasing from South to North. pos[1] = 20.0 - pos[1] - - except: + except KeyError: pos = None # Try and determine if the FLA is powered or not. Paths without an @@ -763,7 +763,7 @@ def find(name): flag=flag, ) - elif cyl == 6: + if cyl == 6: # Dealing with an KKO feed # Determine position @@ -785,7 +785,7 @@ def find(name): flag=flag, ) - elif cyl == 7: + if cyl == 7: # Dealing with a GBO feed # Determine position @@ -807,7 +807,7 @@ def find(name): flag=flag, ) - elif cyl == 8: + if cyl == 8: # Dealing with a HCO feed # Determine position @@ -829,6 +829,8 @@ def find(name): flag=flag, ) + raise RuntimeError("Fell out of the bottom of _get_input_props!") + # Public Functions # ================ @@ -963,7 +965,7 @@ def sensor_to_hk(graph, comp): return HKInput(atmel, chan, int(mux.sn[-2])) - elif comp.type.name == "FLA" or comp.type.name == "RFT thru": + if comp.type.name == "FLA" or comp.type.name == "RFT thru": if comp.type.name == "FLA": try: comp = graph.neighbour_of_type(comp, "RFT thru")[0] @@ -980,8 +982,8 @@ def sensor_to_hk(graph, comp): ) return HKInput(atmel, int(hydra.sn[-1]), None) - else: - raise ValueError("You can only pass components of type LNA, FLA or RFT thru.") + + raise ValueError("You can only pass components of type LNA, FLA or RFT thru.") def hk_to_sensor(graph, inp): @@ -1093,9 +1095,7 @@ def parse_chime_serial(sn): mo = re.match("FCC(\d{2})(\d{2})(\d{2})", sn) if mo is None: - raise RuntimeError( - "Serial number %s does not match expected CHIME format." % sn - ) + raise RuntimeError(f"Serial number {sn} does not match expected CHIME format.") crate = int(mo.group(1)) slot = int(mo.group(2)) @@ -1130,7 +1130,7 @@ def parse_pathfinder_serial(sn): if mo is None: raise RuntimeError( - "Serial number %s does not match expected Pathfinder format." % sn + f"Serial number {sn} does not match expected Pathfinder format." ) crate = mo.group(1) @@ -1164,7 +1164,7 @@ def parse_old_serial(sn): if mo is None: raise RuntimeError( - "Serial number %s does not match expected 8/16 channel format." % sn + f"Serial number {sn} does not match expected 8/16 channel format." ) slot = mo.group(1) @@ -1207,8 +1207,7 @@ def get_pathfinder_channel(slot, sma): 96, 32, ] - channel = c[slot] + sma if slot > 0 else sma - return channel + return c[slot] + sma if slot > 0 else sma # Determine ID try: @@ -1283,7 +1282,7 @@ def get_crate_channel(slot, sma): return default -def get_default_frequency_map_stream() -> Tuple[np.ndarray]: +def get_default_frequency_map_stream() -> tuple[np.ndarray]: """Get the default CHIME frequency map stream. Level order is [shuffle, crate, slot, link]. 
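# Several rewrites in this file come from the "RET" rules: a value assigned
# only to be returned is returned directly (RET504), and an `else` branch
# that follows a `return` is flattened (RET505).  A small sketch of the
# before/after shape, using a made-up lookup table rather than the real
# channel maps:
_CHANNEL_BASE = {1: 16, 2: 32, 3: 48}  # illustrative values only


def example_channel(slot: int, sma: int) -> int:
    # Before:
    #     if slot > 0:
    #         channel = _CHANNEL_BASE[slot] + sma
    #         return channel
    #     else:
    #         return sma
    if slot > 0:
        return _CHANNEL_BASE[slot] + sma  # RET504: return the expression directly

    return sma  # RET505: no else needed after a return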
@@ -1332,7 +1331,7 @@ def order_frequency_map_stream(fmap: np.ndarray, stream_id: np.ndarray) -> np.nd shuffle, crate, slot, link for each frequency """ - def decode_stream_id(sid: int) -> Tuple[int]: + def decode_stream_id(sid: int) -> tuple[int]: link = sid & 15 slot = (sid >> 4) & 15 crate = (sid >> 8) & 15 @@ -1348,9 +1347,7 @@ def decode_stream_id(sid: int) -> Tuple[int]: x[f].append(decoded_stream[ii]) # TODO: maybe implement some checks here - stream = np.array([i[0] for i in x], dtype=np.int32) - - return stream + return np.array([i[0] for i in x], dtype=np.int32) def get_correlator_inputs(lay_time, correlator=None, connect=True): @@ -1410,7 +1407,7 @@ def get_correlator_inputs(lay_time, correlator=None, connect=True): # A hack to return GBO correlator inputs correlator = "tone" connect = False - laytime = 0 + lay_time = 0 return fake_tone_database() if not connect_this_rank(): @@ -1423,12 +1420,13 @@ def get_correlator_inputs(lay_time, correlator=None, connect=True): # Fetch layout_tag start time if we received a layout num if isinstance(lay_time, int): raise ValueError("Layout IDs are no longer supported.") - elif isinstance(lay_time, datetime.datetime): + + if isinstance(lay_time, datetime.datetime): layout_graph = layout.graph.from_db(lay_time) elif isinstance(lay_time, layout.graph): layout_graph = lay_time else: - raise ValueError("Unsupported argument lay_time=%s" % repr(lay_time)) + raise ValueError(f"Unsupported argument lay_time={lay_time!r}") # Fetch all the input components inputs = [] @@ -1448,7 +1446,7 @@ def get_correlator_inputs(lay_time, correlator=None, connect=True): try: corr = layout_graph.component(correlator) except layout.NotFound: - raise ValueError("Unknown correlator %s" % correlator) + raise ValueError("Unknown correlator: " + correlator) # Cut out SMA coaxes so we don't go outside of the correlator sg = set(layout_graph.nodes()) @@ -1637,11 +1635,10 @@ def get_feed_polarisations(feeds): Returns ------- pol : np.ndarray - Array of characters giving polarisation. If not an array feed returns '0'. + Array of characters giving polarisation. If not an array feed returns + '0'. """ - pol = np.array([(f.pol if is_array(f) else "0") for f in feeds]) - - return pol + return np.array([(f.pol if is_array(f) else "0") for f in feeds]) def is_array(feed): @@ -1765,7 +1762,10 @@ def get_noise_channel(inputs): def is_array_on(inputs, *args): - """Check if inputs are attached to an array antenna AND powered on AND flagged as good. + """Check if inputs are on. + + An input is on if they are attached to an array antenna AND powered on AND + flagged as good. Parameters ---------- @@ -1790,8 +1790,7 @@ def is_array_on(inputs, *args): ) # Assume that the argument is a sequence otherwise - else: - return [is_array_on(inp) for inp in inputs] + return [is_array_on(inp) for inp in inputs] # Create an is_chime_on alias for backwards compatibility @@ -1831,14 +1830,15 @@ def reorder_correlator_inputs(input_map, corr_inputs): def redefine_stack_index_map(input_map, prod, stack, reverse_stack): - """Ensure that only baselines between array antennas are used to represent the stack. + """Ensure only baselines between array antennas are used to represent the stack. - The correlator will have inputs that are not connected to array antennas. These inputs - are flagged as bad and are not included in the stack, however, products that contain - their `chan_id` can still be used to represent a characteristic baseline in the `stack` - index map. 
     This method creates a new `stack` index map that, if possible, only contains
-    products between two array antennas. This new `stack` index map should be used when
-    calculating baseline distances to fringestop stacked data.
+    The correlator will have inputs that are not connected to array antennas.
+    These inputs are flagged as bad and are not included in the stack, however,
+    products that contain their `chan_id` can still be used to represent a
+    characteristic baseline in the `stack` index map. This method creates a new
+    `stack` index map that, if possible, only contains products between two
+    array antennas. This new `stack` index map should be used when calculating
+    baseline distances to fringestop stacked data.
 
     Parameters
     ----------
@@ -1848,7 +1848,8 @@ def redefine_stack_index_map(input_map, prod, stack, reverse_stack):
     prod : np.ndarray[nprod,] of dtype=('input_a', 'input_b')
         The correlation products as pairs of inputs.
     stack : np.ndarray[nstack,] of dtype=('prod', 'conjugate')
-        The index into the `prod` axis of a characteristic baseline included in the stack.
+        The index into the `prod` axis of a characteristic baseline included
+        in the stack.
     reverse_stack : np.ndarray[nprod,] of dtype=('stack', 'conjugate')
         The index into the `stack` axis that each `prod` belongs.
 
@@ -1898,8 +1899,8 @@ def cmap(i, j, n):
     """
     if i <= j:
         return (n * (n + 1) // 2) - ((n - i) * (n - i + 1) // 2) + (j - i)
-    else:
-        return cmap(j, i, n)
+
+    return cmap(j, i, n)
 
 
 def icmap(ix, n):
@@ -1990,7 +1991,8 @@ def unpack_product_array(prod_arr, axis=1, feeds=None):
 
 def pack_product_array(exp_arr, axis=1):
     """Pack full correlation matrices into upper triangular form.
 
-    It replaces the two feed axes of the matrix, with a single upper triangle product axis.
+    It replaces the two feed axes of the matrix, with a single upper
+    triangle product axis.
 
 
     Parameters
@@ -2081,7 +2083,8 @@ def rankN_approx(A, rank=1):
 
 def eigh_no_diagonal(A, niter=5, eigvals=None):
     """Eigenvalue decomposition ignoring the diagonal elements.
 
-    The diagonal elements are iteratively replaced with those from a rank=1 approximation.
+    The diagonal elements are iteratively replaced with those from a rank=1
+    approximation.
 
     Parameters
     ----------
@@ -2326,8 +2329,8 @@ def fringestop_time(
 
     if timestream.shape != expected_shape:
         raise ValueError(
-            "The shape of the timestream %s does not match the expected shape %s"
-            % (timestream.shape, expected_shape)
+            f"The shape of the timestream ({timestream.shape}) "
+            f"does not match the expected shape: {expected_shape}"
         )
 
     delays = delay(
@@ -2414,8 +2417,8 @@ def decorrelation(
 
     if timestream.shape[1:] != expected_shape:
         raise ValueError(
-            "The shape of the timestream %s does not match the expected shape %s"
-            % (timestream.shape, expected_shape)
+            f"The shape of the timestream ({timestream.shape}) "
+            f"does not match the expected shape: {expected_shape}"
         )
 
     delays = delay(
@@ -2548,8 +2551,7 @@ def beam_index2number(beam_index):
     """
     beam_ew_index = beam_index // 256
     beam_ns_index = beam_index % 256
-    beam_number = 1000 * beam_ew_index + beam_ns_index
-    return beam_number
+    return 1000 * beam_ew_index + beam_ns_index
 
 
 def invert_no_zero(*args, **kwargs):
@@ -2558,7 +2560,7 @@ def invert_no_zero(*args, **kwargs):
     import warnings
 
     warnings.warn(
-        f"Function invert_no_zero is deprecated - use 'caput.tools.invert_no_zero'",
+        "Function invert_no_zero is deprecated - use 'caput.tools.invert_no_zero'",
         category=DeprecationWarning,
     )
     return tools.invert_no_zero(*args, **kwargs)
diff --git a/doc/conf.py b/doc/conf.py
index a89fe6f0..51555c8a 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 #
 # ch_util documentation build configuration file, created by
 # sphinx-quickstart on Thu Oct 10 12:52:16 2013.
diff --git a/pyproject.toml b/pyproject.toml
index 0ee57e54..a268a9e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,12 @@ test = [
     "pytest >= 7.0"
 ]
 
+[tool.ruff]
+lint.select = ["E", "F", "UP", "NPY", "BLE", "C4", "RET"]
+lint.ignore = [
+    "UP007" # Can't convert to "X | Y" type hinting until 3.10
+]
+
 [tool.setuptools.package-data]
 "ch_util.catalogs" = ["*.json"]
 
diff --git a/scripts/scan2txt.py b/scripts/scan2txt.py
index 59550b97..676f1b79 100755
--- a/scripts/scan2txt.py
+++ b/scripts/scan2txt.py
@@ -16,7 +16,7 @@
 
 
 # For echoing stdout to a file.
-class tee_stdout(object):
+class tee_stdout:
     def __init__(self, stdout, fp):
         self.stdout = stdout
         self.fp = fp
@@ -37,10 +37,11 @@ def write(self, txt):
 
 def parse_cmd(x):
     if not x:
         return False
+
     if not re.match(r"\$CMD\$.+", x):
         return False
-    else:
-        return x[5:]
+
+    return x[5:]
 
 
 def test_cmd_done(x):
@@ -88,40 +89,41 @@ def get_barcode(msg, test=None, fmt="", override=True):
     overridden = False
     for i in range(0, 3):
-        x = input(" Scan %s barcode: " % msg)
+        x = input(f" Scan {msg} barcode: ")
         if parse_cmd(x):
             if parse_cmd(x) == CMD_OVERRIDE:
                 return input(" Override in effect. Enter any barcode: ")
-            else:
-                return x
-        elif test:
+
+            return x
+
+        if test:
             if test(x):
                 return x
-            else:
-                x1 = x
-                print(' Was expecting format "%s"!' % fmt)
-                if override:
-                    x = input(" Same barcode to over-ride, or correct: ")
+
+            x1 = x
+            print(f' Was expecting format "{fmt}"!')
+            if override:
+                x = input(" Same barcode to over-ride, or correct: ")
+                if x == x1:
+                    x = input(" Confirm over-ride by entering once more: ")
                 if x == x1:
-                    x = input(" Confirm over-ride by entering once more: ")
-                    if x == x1:
-                        print(" Over-ride successful.")
-                        overridden = True
-                        return x
-                    elif test(x):
+                    print(" Over-ride successful.")
+                    overridden = True
+                    return x
+                if test(x):
                     return x
-                    else:
-                        print(
-                            " Bad over-ride and barcode not of format "
-                            '"%s". Try again.' % fmt
-                        )
-                elif test(x):
+
+                print(
+                    " Bad over-ride and barcode not of format "
+                    f'"{fmt}". Try again.'
+                )
+            elif test(x):
+                return x
+            else:
+                print(
+                    " Bad over-ride and barcode not of format "
+                    f'"{fmt}". Try again.'
+                )
         else:
             return x
     print(" Giving up and cancelling current scan.")
@@ -190,12 +192,10 @@ def get_can():
 
 def commit_chain(fp, cmd, chain, name):
     if not parse_cmd(cmd) == CMD_CANCEL:
-        fp.write(
-            "# Chain type: %s. Scanned: %s.\n"
-            % (name, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
-        )
+        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        fp.write(f"# Chain type: {name}. Scanned: {now}.\n")
         for x in chain:
-            fp.write("%s\n" % x)
+            fp.write(x + "\n")
         fp.write("\n")
         fp.flush()
         print(" Chain written to disc.")
@@ -218,9 +218,9 @@ def get_freeform_chain(fp):
             if parse_cmd(x) == CMD_DONE or parse_cmd(x) == CMD_DONE_QUIT:
                 commit_chain(fp, x, chain, "free form")
                 break
-            else:
-                print(" Cancelling.")
-                return
+
+            print(" Cancelling.")
+            return
         if len(chain):
             if x == chain[-1]:
                 print(
@@ -233,8 +233,8 @@ def get_freeform_chain(fp):
                     x = input(" Chain complete. Enter anything to commit, or CANCEL: ")
                     commit_chain(fp, x, chain, "free form")
                     break
-                else:
-                    chain.append(x)
+
+                chain.append(x)
             else:
                 chain.append(x)
         else:
@@ -245,31 +245,32 @@ def get_freeform_chain(fp):
 
 def get_chain(fp, get_list, name):
     chain = []
-    print("Starting chain of type: %s." % name)
+    print(f"Starting chain of type: {name}.")
     print("-" * 80)
-    for l in get_list:
+    for item in get_list:
         while 1:
-            x = l()
+            x = item()
             if parse_cmd(x) or not x:
                 if parse_cmd(x) == CMD_DONE or parse_cmd(x) == CMD_DONE_QUIT:
                     print(
                         " Preset chains cannot be fininshed early. Finish or cancel."
                     )
                     continue
-                else:
-                    print(" Cancelling chain.")
-                    return False
+
+                print(" Cancelling chain.")
+                return False
             chain.append(x)
             print(" Chain so far: ", chain)
             print()
-            if overridden == False:
+            if not overridden:
                 break
     x = input(" Chain complete. Enter anything to commit, or CANCEL: ")
     commit_chain(fp, x, chain, name)
+    return None
 
 
 if len(sys.argv) != 2:
-    print("Usage: %s " % sys.argv[0])
+    print(f"Usage: {sys.argv[0]} ")
     exit()
 
 # Predefined chains.
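# The scan2txt.py rewrites above all follow the early-return pattern enforced
# by the RET rules selected in pyproject.toml (e.g. RET505, "unnecessary
# `else` after `return`"): once a branch ends in `return`, the trailing
# `else:` is dropped and its body is dedented. A minimal sketch of the shape
# of these edits; `classify` is a made-up helper, not part of ch_util:
#
#     def classify(barcode):          # before: flagged by RET505
#         if barcode.startswith("ANT"):
#             return "antenna"
#         else:
#             return "other"

def classify(barcode):
    """Return a label for a scanned barcode (after: early return, no else)."""
    if barcode.startswith("ANT"):
        return "antenna"

    return "other"


assert classify("ANT0001B") == "antenna"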
@@ -286,7 +287,8 @@ def get_chain(fp, get_list, name):
 
 # sys.stdout = tee_stdout(sys.stdout, fp_log)
 sys.stdin = tee_stdout(sys.stdin, fp_log)
 
-fp.write("USER %s\n" % input("Please enter/scan your name: "))
+name = input("Please enter/scan your name: ")
+fp.write(f"USER {name}\n")
 fp.flush()
 
 while 1:
diff --git a/scripts/singleuse/generate_archive_test_data_2_X.py b/scripts/singleuse/generate_archive_test_data_2_X.py
index 26c8cde0..818913ed 100644
--- a/scripts/singleuse/generate_archive_test_data_2_X.py
+++ b/scripts/singleuse/generate_archive_test_data_2_X.py
@@ -53,7 +53,7 @@ def dset_filter(dataset):
 
 for ii, d in enumerate(data_list):
     out_f = h5py.File(OUT_FILENAMES[ii], "w")
-    tdata = andata.concatenate(
+    andata.concatenate(
         [d],
         start=STARTS[ii],
         stop=STOPS[ii],
diff --git a/scripts/singleuse/generate_archive_test_data_3_X.py b/scripts/singleuse/generate_archive_test_data_3_X.py
index 74181a95..79f9d568 100644
--- a/scripts/singleuse/generate_archive_test_data_3_X.py
+++ b/scripts/singleuse/generate_archive_test_data_3_X.py
@@ -44,7 +44,7 @@ def dset_filter(dataset):
 
 for ii, d in enumerate(data_list):
     out_f = h5py.File(OUT_FILENAMES[ii], "w")
-    tdata = andata.concatenate(
+    andata.concatenate(
         [d],
         start=STARTS[ii],
         stop=STOPS[ii],
diff --git a/scripts/update_psrcat.py b/scripts/update_psrcat.py
index 307fdd87..50d9da1b 100644
--- a/scripts/update_psrcat.py
+++ b/scripts/update_psrcat.py
@@ -27,7 +27,7 @@
         holo.HolographySource.name.regexp("^[BJ][0-9]{4}\+[0-9]*$")
     )
 ]
-print("Found {:d} pulsars in database.".format(len(pulsars)))
+print(f"Found {len(pulsars):d} pulsars in database.")
 
 # Query ATNF catalog
 flux_fields = [
@@ -69,18 +69,14 @@
             name = psr
             break
     if name is None:
-        print(
-            "Failed to match ATNF entry {} to queried database pulsars.".format(
-                alt_names
-            )
-        )
+        print(f"Failed to match ATNF entry {alt_names} to queried database pulsars.")
         continue
 
     # Create a new catalog entry
     if name in FluxCatalog:
-        print("{} already in catalog. Skipping.".format(name))
-        print("Alt names: {}".format(alt_names))
-        print("Found: {}".format(FluxCatalog[name].name))
+        print(f"{name} already in catalog. Skipping.")
+        print(f"Alt names: {alt_names}")
+        print(f"Found: {FluxCatalog[name].name}")
         continue
 
     # Add flux measurements
@@ -106,5 +102,5 @@
 entry.fit_model()
 
 # Dump to file
-print("Saving catalog to: %s" % CATALOG_NAME)
+print("Saving catalog to: " + CATALOG_NAME)
 FluxCatalog.dump(CATALOG_NAME)
diff --git a/tests/test_andata.py b/tests/test_andata.py
index f8343b3a..47250769 100644
--- a/tests/test_andata.py
+++ b/tests/test_andata.py
@@ -181,8 +181,8 @@ def getbase(a):
             b = a.base
             if b is None:
                 return a
-            else:
-                return getbase(b)
+
+            return getbase(b)
 
         vis = np.arange(60)
         vis.shape = (3, 2, 10)
diff --git a/tests/test_andata_archive2.py b/tests/test_andata_archive2.py
index 541e8108..975882c2 100644
--- a/tests/test_andata_archive2.py
+++ b/tests/test_andata_archive2.py
@@ -6,10 +6,9 @@
 import h5py
 
 from ch_util import andata
-from caput.memh5 import MemGroup
 
 import data_paths
 
-# archive_acq_dir = ("/scratch/k/krs/jrs65/chime_archive/20140913T055455Z_blanchard_corr/")
+# archive_acq_dir="/scratch/k/krs/jrs65/chime_archive/20140913T055455Z_blanchard_corr/"
 # archive_acq_fname_list = []
 # fmt_corr = re.compile("([0-9]{8})_([0-9]{4}).h5")
 # for f in os.listdir(archive_acq_dir):
@@ -279,7 +278,7 @@ def test_prod_sel_fancy(self):
 
 
 def _resolve_prod_input_sel(prod_sel, prod_map, input_sel, input_map):
     """Legacy code pasted here for regression testing."""
-    if (not prod_sel is None) and (not input_sel is None):
+    if (prod_sel is not None) and (input_sel is not None):
         # This should never happen due to previouse checks.
         raise ValueError("*input_sel* and *prod_sel* both specified.")
 
@@ -296,7 +295,7 @@ def _resolve_prod_input_sel(prod_sel, prod_map, input_sel, input_map):
             input_sel.append(p0)
             input_sel.append(p1)
         # ensure_1D here deals with h5py issue #425.
-        input_sel = andata._ensure_1D_selection(sorted(list(set(input_sel))))
+        input_sel = andata._ensure_1D_selection(sorted(set(input_sel)))
     else:
         input_sel = andata._ensure_1D_selection(input_sel)
     inputs = list(np.arange(len(input_map), dtype=int)[input_sel])
diff --git a/tests/test_andata_archive3.py b/tests/test_andata_archive3.py
index 00567054..7945661f 100644
--- a/tests/test_andata_archive3.py
+++ b/tests/test_andata_archive3.py
@@ -7,10 +7,8 @@
 import glob
 
 import numpy as np
-import h5py
 
 from ch_util import andata
-from caput.memh5 import MemGroup
 
 import data_paths
 
 tempdir = tempfile.mkdtemp()
@@ -70,10 +68,10 @@ def test_stack_sel(self):
     def test_no_prod_input_sel(self):
         """Test that you can't use input/prod sel on stacked data."""
         with self.assertRaises(ValueError):
-            ad = andata.CorrData.from_acq_h5(self.file_list, input_sel=[0, 15])
+            andata.CorrData.from_acq_h5(self.file_list, input_sel=[0, 15])
 
         with self.assertRaises(ValueError):
-            ad = andata.CorrData.from_acq_h5(self.file_list, prod_sel=[0, 15])
+            andata.CorrData.from_acq_h5(self.file_list, prod_sel=[0, 15])
 
 
 if __name__ == "__main__":
diff --git a/tests/test_andata_dist.py b/tests/test_andata_dist.py
index 43c346c3..4f781d41 100644
--- a/tests/test_andata_dist.py
+++ b/tests/test_andata_dist.py
@@ -13,7 +13,9 @@
 
 comm = MPI.COMM_WORLD
 
-# fnames = glob.glob('/scratch/k/krs/jrs65/chime_archive/20140916T173334Z_blanchard_corr/000[0-1]*_0000.h5')
+# fnames = glob.glob(
+#     '/scratch/k/krs/jrs65/chime_archive/20140916T173334Z_blanchard_corr/000[0-1]*_0000.h5'
+# )
 # All the test data file names.
 # Test data files have 32 frequencies, 136 products, and 31, 17 times.