Skip to content

Commit

Permalink
update ksource
Browse files Browse the repository at this point in the history
  • Loading branch information
djps committed Sep 27, 2024
1 parent 8389bf2 commit 96ea282
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions kwave/ksource.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,11 @@ def validate(self, kgrid: kWaveGrid) -> None:

# if more than one time series is given, check the number of time
# series given matches the number of source elements
if (self.flag.source_ux and np.size(self.ux)[0] != np.size(u_unique) or \
self.flag.source_uy and np.size(self.uy)[0] != np.size(u_unique) or \
self.flag.source_uz and np.size(self.uz)[0] != np.size(u_unique)):
nonzero_labels: int = np.size(np.nonzero(u_unique))
# print(u_unique, np.size(np.nonzero(u_unique)), np.size(self.ux), np.shape(self.ux), np.size(u_unique) )
if (self.flag_ux > 0 and np.shape(self.ux)[0] != nonzero_labels or \
self.flag_uy > 0 and np.shape(self.uy)[0] != nonzero_labels or \
self.flag_uz > 0 and np.shape(self.uz)[0] != nonzero_labels):
raise ValueError(
"The number of time series in source.ux (etc) "
"must match the number of labelled source elements in source.u_mask.", np.size(self.ux)[0], np.size(u_unique)
Expand All @@ -277,18 +279,18 @@ def validate(self, kgrid: kWaveGrid) -> None:

# set source flags to the length of the sources; this allows the
# inputs to be defined independently and to be of any length
if self.sxx is not None and np.size(self.sxx) >= kgrid.Nt:
logging.log(logging.WARN, " source.sxx has more time points than kgrid.Nt," " remaining time points will not be used.")
if self.syy is not None and np.size(self.syy) >= kgrid.Nt:
logging.log(logging.WARN, " source.syy has more time points than kgrid.Nt," " remaining time points will not be used.")
if self.szz is not None and np.size(self.szz) >= kgrid.Nt:
logging.log(logging.WARN, " source.szz has more time points than kgrid.Nt," " remaining time points will not be used.")
if self.sxy is not None and np.size(self.sxy) >= kgrid.Nt:
logging.log(logging.WARN, " source.sxy has more time points than kgrid.Nt," " remaining time points will not be used.")
if self.sxz is not None and np.size(self.sxz) >= kgrid.Nt:
logging.log(logging.WARN, " source.sxz has more time points than kgrid.Nt," " remaining time points will not be used.")
if self.syz is not None and np.size(self.syz) >= kgrid.Nt:
logging.log(logging.WARN, " source.syz has more time points than kgrid.Nt," " remaining time points will not be used.")
if self.sxx is not None and np.max(np.shape(self.sxx)) > kgrid.Nt:
logging.log(logging.WARN, " source.sxx has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.sxx))))
if self.syy is not None and np.max(np.shape(self.syy)) > kgrid.Nt:
logging.log(logging.WARN, " source.syy has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.syy))))
if self.szz is not None and np.max(np.shape(self.szz)) > kgrid.Nt:
logging.log(logging.WARN, " source.szz has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.szz))))
if self.sxy is not None and np.max(np.shape(self.sxy)) > kgrid.Nt:
logging.log(logging.WARN, " source.sxy has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.sxy))))
if self.sxz is not None and np.max(np.shape(self.sxz)) > kgrid.Nt:
logging.log(logging.WARN, " source.sxz has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.sxz))))
if self.syz is not None and np.max(np.shape(self.syz)) > kgrid.Nt:
logging.log(logging.WARN, " source.syz has more time points than kgrid.Nt, remaining time points will not be used - " + str(np.max(np.shape(self.syz))))

# create an indexing variable corresponding to the location of all the source elements
# raise NotImplementedError
Expand All @@ -298,6 +300,7 @@ def validate(self, kgrid: kWaveGrid) -> None:

# create a second indexing variable
if np.size(s_unique) <= 2 and np.sum(s_unique) == 1:

s_mask_sum = np.array(self.s_mask).sum()

# if more than one time series is given, check the number of time series given matches the number of source elements
Expand All @@ -317,19 +320,19 @@ def validate(self, kgrid: kWaveGrid) -> None:
raise ValueError("The number of time series in source.sxx (etc) must match the number of source elements in source.s_mask.")

else:
# check the source labels are monotonic, and start from 1
if np.sum(s_unique[1:-1] - s_unique[0:-2]) != (np.size(s_unique) - 1) or (not (s_unique == 0).any()):
# check the source labels are monotonic, and start from 0
if np.sum(s_unique[1:-1] - s_unique[0:-2]) != (np.size(s_unique) - 2) or (not (s_unique == 0).any()):
raise ValueError("If using a labelled source.s_mask, the source labels must be monotonically increasing and start from 0.")

numel_s_unique: int = np.size(s_unique) - 1

# if more than one time series is given, check the number of time series given matches the number of source elements
if ((self.sxx and np.shape(self.sxx)[0] != numel_s_unique) or
(self.syy and np.shape(self.syy)[0] != numel_s_unique) or
(self.szz and np.shape(self.szz)[0] != numel_s_unique) or
(self.sxy and np.shape(self.sxy)[0] != numel_s_unique) or
(self.sxz and np.shape(self.sxz)[0] != numel_s_unique) or
(self.syz and np.shape(self.syz)[0] != numel_s_unique)):
if ((self.sxx is not None and np.shape(self.sxx)[0] != numel_s_unique) or
(self.syy is not None and np.shape(self.syy)[0] != numel_s_unique) or
(self.szz is not None and np.shape(self.szz)[0] != numel_s_unique) or
(self.sxy is not None and np.shape(self.sxy)[0] != numel_s_unique) or
(self.sxz is not None and np.shape(self.sxz)[0] != numel_s_unique) or
(self.syz is not None and np.shape(self.syz)[0] != numel_s_unique)):
raise ValueError("The number of time series in source.sxx (etc) must match the number of labelled source elements in source.u_mask.")

@property
Expand Down

0 comments on commit 96ea282

Please sign in to comment.