From 96ea282d0dedbb1407dcc1a7d3199366bd74b0fc Mon Sep 17 00:00:00 2001 From: "Sinden, David" Date: Fri, 27 Sep 2024 17:03:53 +0200 Subject: [PATCH] update ksource --- kwave/ksource.py | 49 +++++++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/kwave/ksource.py b/kwave/ksource.py index e5fdda5f..8a85c9be 100644 --- a/kwave/ksource.py +++ b/kwave/ksource.py @@ -248,9 +248,11 @@ def validate(self, kgrid: kWaveGrid) -> None: # if more than one time series is given, check the number of time # series given matches the number of source elements - if (self.flag.source_ux and np.size(self.ux)[0] != np.size(u_unique) or \ - self.flag.source_uy and np.size(self.uy)[0] != np.size(u_unique) or \ - self.flag.source_uz and np.size(self.uz)[0] != np.size(u_unique)): + nonzero_labels: int = np.size(np.nonzero(u_unique)) + # print(u_unique, np.size(np.nonzero(u_unique)), np.size(self.ux), np.shape(self.ux), np.size(u_unique) ) + if (self.flag_ux > 0 and np.shape(self.ux)[0] != nonzero_labels or \ + self.flag_uy > 0 and np.shape(self.uy)[0] != nonzero_labels or \ + self.flag_uz > 0 and np.shape(self.uz)[0] != nonzero_labels): raise ValueError( "The number of time series in source.ux (etc) " "must match the number of labelled source elements in source.u_mask.", np.size(self.ux)[0], np.size(u_unique) @@ -277,18 +279,18 @@ def validate(self, kgrid: kWaveGrid) -> None: # set source flgs to the length of the sources, this allows the # inputs to be defined independently and be of any length - if self.sxx is not None and np.size(self.sxx) >= kgrid.Nt: - logging.log(logging.WARN, " source.sxx has more time points than kgrid.Nt," " remaining time points will not be used.") - if self.syy is not None and np.size(self.syy) >= kgrid.Nt: - logging.log(logging.WARN, " source.syy has more time points than kgrid.Nt," " remaining time points will not be used.") - if self.szz is not None and np.size(self.szz) >= kgrid.Nt: - logging.log(logging.WARN, " source.szz has more time points than kgrid.Nt," " remaining time points will not be used.") - if self.sxy is not None and np.size(self.sxy) >= kgrid.Nt: - logging.log(logging.WARN, " source.sxy has more time points than kgrid.Nt," " remaining time points will not be used.") - if self.sxz is not None and np.size(self.sxz) >= kgrid.Nt: - logging.log(logging.WARN, " source.sxz has more time points than kgrid.Nt," " remaining time points will not be used.") - if self.syz is not None and np.size(self.syz) >= kgrid.Nt: - logging.log(logging.WARN, " source.syz has more time points than kgrid.Nt," " remaining time points will not be used.") + if self.sxx is not None and np.max(np.shape(self.sxx)) > kgrid.Nt: + logging.log(logging.WARN, " source.sxx has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.sxx)))) + if self.syy is not None and np.max(np.shape(self.syy)) > kgrid.Nt: + logging.log(logging.WARN, " source.syy has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.syy)))) + if self.szz is not None and np.max(np.shape(self.szz)) > kgrid.Nt: + logging.log(logging.WARN, " source.szz has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.szz)))) + if self.sxy is not None and np.max(np.shape(self.sxy)) > kgrid.Nt: + logging.log(logging.WARN, " source.sxy has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.sxy)))) + if self.sxz is not None and np.max(np.shape(self.sxz)) > kgrid.Nt: + logging.log(logging.WARN, " source.sxz has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.sxz)))) + if self.syz is not None and np.max(np.shape(self.syz)) > kgrid.Nt: + logging.log(logging.WARN, " source.syz has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.syz)))) # create an indexing variable corresponding to the location of all the source elements # raise NotImplementedError @@ -298,6 +300,7 @@ def validate(self, kgrid: kWaveGrid) -> None: # create a second indexing variable if np.size(s_unique) <= 2 and np.sum(s_unique) == 1: + s_mask_sum = np.array(self.s_mask).sum() # if more than one time series is given, check the number of time series given matches the number of source elements @@ -317,19 +320,19 @@ def validate(self, kgrid: kWaveGrid) -> None: raise ValueError("The number of time series in source.sxx (etc) must match the number of source elements in source.s_mask.") else: - # check the source labels are monotonic, and start from 1 - if np.sum(s_unique[1:-1] - s_unique[0:-2]) != (np.size(s_unique) - 1) or (not (s_unique == 0).any()): + # check the source labels are monotonic, and start from 0 + if np.sum(s_unique[1:-1] - s_unique[0:-2]) != (np.size(s_unique) - 2) or (not (s_unique == 0).any()): raise ValueError("If using a labelled source.s_mask, the source labels must be monotonically increasing and start from 0.") numel_s_unique: int = np.size(s_unique) - 1 # if more than one time series is given, check the number of time series given matches the number of source elements - if ((self.sxx and np.shape(self.sxx)[0] != numel_s_unique) or - (self.syy and np.shape(self.syy)[0] != numel_s_unique) or - (self.szz and np.shape(self.szz)[0] != numel_s_unique) or - (self.sxy and np.shape(self.sxy)[0] != numel_s_unique) or - (self.sxz and np.shape(self.sxz)[0] != numel_s_unique) or - (self.syz and np.shape(self.syz)[0] != numel_s_unique)): + if ((self.sxx is not None and np.shape(self.sxx)[0] != numel_s_unique) or + (self.syy is not None and np.shape(self.syy)[0] != numel_s_unique) or + (self.szz is not None and np.shape(self.szz)[0] != numel_s_unique) or + (self.sxy is not None and np.shape(self.sxy)[0] != numel_s_unique) or + (self.sxz is not None and np.shape(self.sxz)[0] != numel_s_unique) or + (self.syz is not None and np.shape(self.syz)[0] != numel_s_unique)): raise ValueError("The number of time series in source.sxx (etc) must match the number of labelled source elements in source.u_mask.") @property