Skip to content

Commit

Permalink
style: apply new black
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed Mar 18, 2024
1 parent 858c2d7 commit 47fb92f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
22 changes: 14 additions & 8 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,9 @@ def _calculate_delays(
# NOTE: this not obviously the right level for this, but it's the only baseclass in
# common to where it's used
def _cut_data(
self, data: np.ndarray, weight: np.ndarray,
self,
data: np.ndarray,
weight: np.ndarray,
) -> Optional[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""Apply cuts on the data and weights and returned modified versions.
Expand All @@ -587,8 +589,12 @@ def _cut_data(
The selection of times retained.
"""
ntime, nfreq = data.shape[-2:]
non_zero_time = (weight > 0).mean(axis=-1).reshape(-1, ntime).mean(axis=0) > self.time_frac
non_zero_freq = (weight > 0).mean(axis=-2).reshape(-1, nfreq).mean(axis=0) > self.freq_frac
non_zero_time = (weight > 0).mean(axis=-1).reshape(-1, ntime).mean(
axis=0
) > self.time_frac
non_zero_freq = (weight > 0).mean(axis=-2).reshape(-1, nfreq).mean(
axis=0
) > self.freq_frac

# If there are no non-zero weighted entries skip
if not non_zero_freq.any():
Expand Down Expand Up @@ -619,8 +625,8 @@ def _cut_data(
# obtain constant total power
if self.scale_freq:
dscl = (
data.std(axis=-2)[..., np.newaxis, :] /
data.std(axis=(-1, -2))[..., np.newaxis, np.newaxis]
data.std(axis=-2)[..., np.newaxis, :]
/ data.std(axis=(-1, -2))[..., np.newaxis, np.newaxis]
)
data = data * tools.invert_no_zero(dscl)

Expand Down Expand Up @@ -2026,7 +2032,7 @@ def loglike(self, s_a: np.ndarray) -> float:
ll = lndet + np.diagonal(CiX).sum().real

ll += self.alpha * s_a @ self.IW2 @ s_a
#ll += (s_a @ (self.Cdi @ s_a)).real
# ll += (s_a @ (self.Cdi @ s_a)).real
return ll

def gradient(self, s_a: np.ndarray) -> np.ndarray:
Expand All @@ -2035,7 +2041,7 @@ def gradient(self, s_a: np.ndarray) -> np.ndarray:

g = -(self._Ut.conj() * self._XC_Ut).sum(axis=0).real
g += self.alpha * self.IW2 @ s_a
#g += (self.Cdi @ s_a).real
# g += (self.Cdi @ s_a).real
return g

def hessian(self, s_a: np.ndarray) -> np.ndarray:
Expand All @@ -2053,7 +2059,7 @@ def hessian(self, s_a: np.ndarray) -> np.ndarray:
t = -(self._Wt.conj() * self._XC_Wt).sum(axis=0).real
H += np.diag(t.real)
H += self.alpha * self.IW2
#H += self.Cdi
# H += self.Cdi

return H

Expand Down
14 changes: 9 additions & 5 deletions draco/analysis/delayopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def delay_power_spectrum_maxpost(
Did the solve successfully converge.
"""
from .delay import fourier_matrix

nsamp, Nf = data.shape

if fsel is None:
Expand Down Expand Up @@ -372,10 +373,14 @@ def delay_power_spectrum_maxpost(
else:
lsi = np.log(initial_S)

optfunc = AddFunctions([
LogLikePS(X, F, Nm, nsamp, exact_hessian=True),
GaussianProcessPrior(N, alpha=5, width=3.0, kernel="gaussian", a=5.0, reg=1e-8),
])
optfunc = AddFunctions(
[
LogLikePS(X, F, Nm, nsamp, exact_hessian=True),
GaussianProcessPrior(
N, alpha=5, width=3.0, kernel="gaussian", a=5.0, reg=1e-8
),
]
)

samples = []

Expand All @@ -396,4 +401,3 @@ def _get_intermediate(xk):
# NOTE: the final sample in samples is already the final result

return samples, res.success

0 comments on commit 47fb92f

Please sign in to comment.