Skip to content

Commit

Permalink
Added option to correct for template itself in sd_ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
DradeAW committed Nov 20, 2023
1 parent f18eb9c commit 848ee93
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
16 changes: 15 additions & 1 deletion src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,7 @@ def compute_sd_ratio(
wvf_extractor: WaveformExtractor,
censored_period_ms: float = 4.0,
correct_for_drift: bool = True,
correct_for_template_itself: bool = True,
unit_ids=None,
**kwargs,
):
Expand All @@ -1387,11 +1388,13 @@ def compute_sd_ratio(
The censored period in milliseconds. This is to remove any potential bursts that could affect the SD.
correct_for_drift: bool, default: True
If True, will subtract the amplitudes sequentiially to significantly reduce the impact of drift.
correct_for_template_itself: bool, default: True
If true, will take into account that the template itself impacts the standard deviation of the noise,
and will make a rough estimation of what that impact is (and remove it).
unit_ids : list or None, default: None
The list of unit ids to compute this metric. If None, all units are used.
**kwargs:
Keyword arguments for computing spike amplitudes and extremum channel.
TODO: Possibly remove spikes when computing noise?
TODO: Take jitter into account.
Returns
Expand Down Expand Up @@ -1428,6 +1431,7 @@ def compute_sd_ratio(
wvf_extractor.recording, return_scaled=amplitudes_ext._params["return_scaled"], method="std"
)
best_channels = get_template_extremum_channel(wvf_extractor, outputs="index", **kwargs)
n_spikes = wvf_extractor.sorting.count_num_spikes_per_unit()

sd_ratio = {}
for unit_id in unit_ids:
Expand All @@ -1449,6 +1453,16 @@ def compute_sd_ratio(
best_channel = best_channels[unit_id]
std_noise = noise_levels[best_channel]

if correct_for_template_itself:
template = wvf_extractor.get_template(unit_id, force_dense=True)[:, best_channel]

# Computing the variance of a trace that is all 0 and n_spikes non-overlapping template.
# TODO: Take into account that templates for different segments might differ.
p = wvf_extractor.nsamples * n_spikes[unit_id] / wvf_extractor.get_total_samples()
total_variance = p * np.mean(template**2) - p**2 * np.mean(template)

std_noise = np.sqrt(std_noise**2 - total_variance)

sd_ratio[unit_id] = unit_std / std_noise

return sd_ratio
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _simulated_data():


def _waveform_extractor_simple():
recording, sorting = toy_example(duration=80, seed=10)
recording, sorting = toy_example(duration=80, seed=10, firing_rate=6.0)
recording = recording.save(folder=cache_folder / "rec1")
sorting = sorting.save(folder=cache_folder / "sort1")
folder = cache_folder / "waveform_folder1"
Expand Down Expand Up @@ -384,7 +384,7 @@ def test_calculate_sd_ratio(waveform_extractor_simple):
sd_ratio = compute_sd_ratio(waveform_extractor_simple)

assert np.all(list(sd_ratio.keys()) == waveform_extractor_simple.unit_ids)
assert np.allclose(np.array(list(sd_ratio.values())), 1, atol=0.5, rtol=0)
assert np.allclose(list(sd_ratio.values()), 1, atol=0.2, rtol=0)


if __name__ == "__main__":
Expand Down

0 comments on commit 848ee93

Please sign in to comment.