Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve phase shift memory efficiency #2946

Merged
merged 8 commits into from
Jun 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 55 additions & 19 deletions src/spikeinterface/preprocessing/phase_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
add_zeros=True,
window_on_margin=True,
)

traces_shift = apply_fshift_sam(traces_chunk, self.sample_shifts[channel_indices], axis=0)
# traces_shift = apply_fshift_ibl(traces_chunk, self.sample_shifts, axis=0)
traces_shift = apply_frequency_shift(traces_chunk, self.sample_shifts[channel_indices], axis=0)

traces_shift = traces_shift[left_margin:-right_margin, :]
if self.tmp_dtype is not None:
Expand All @@ -116,28 +114,66 @@ def get_traces(self, start_frame, end_frame, channel_indices):
phase_shift = define_function_from_class(source_class=PhaseShiftRecording, name="phase_shift")


def apply_fshift_sam(sig, sample_shifts, axis=0):
def apply_frequency_shift(signal, shift_samples, axis=0):
"""
Apply the shift on a traces buffer.
Apply frequency shift to a signal buffer. This allow for shifting that are sub-sample accurate.

Parameters
----------
signal : ndarray
Input signal array to be shifted.
shift_samples : ndarray
Array of sample shifts for each channel. Phase shifts are in units of 1/sampling_rate.
axis : int, optional
Axis along which to perform the shift. Currently, only axis=0 is supported.

Returns
-------
shifted_signal : ndarray
Signal array with the applied frequency shifts.

Notes
-----
The function works by transforming the signal to the frequency domain using the real FFT (rFFT),
applying phase shifts, and then transforming back to the time domain using the inverse real FFT (irFFT).
The phase shifts are calculated based on the frequency grid obtained from the FFT.

The key steps are:
1. Compute the rFFT of the input signal.
2. Calculate the frequency grid and use it to compute the phase shifts.
3. Apply the phase shifts in the frequency domain.
4. Perform the inverse rFFT to obtain the shifted signal in the time domain.

This method leverages the properties of the Fourier transform, where a phase shift in the frequency domain
corresponds to a time shift in the time domain.
"""
n = sig.shape[axis]
sig_f = np.fft.rfft(sig, axis=axis)
if n % 2 == 0:
# n is even sig_f[-1] is nyquist and so pi
omega = np.linspace(0, np.pi, sig_f.shape[axis])
else:
# n is odd sig_f[-1] is exactly nyquist!! we need (n-1) / n factor!!
omega = np.linspace(0, np.pi * (n - 1) / n, sig_f.shape[axis])
# broadcast omega and sample_shifts depend the axis
import scipy.fft

signal_length = signal.shape[axis]
num_channels = shift_samples.size
fourier_signal_size = signal_length // 2 + 1

frequency_domain_signal = scipy.fft.rfft(signal, n=signal_length, axis=axis, overwrite_x=True)
fourier_signal_size = frequency_domain_signal.shape[0]

if axis == 0:
shifts = omega[:, np.newaxis] * sample_shifts[np.newaxis, :]
frequency_grid = np.empty(shape=(fourier_signal_size, num_channels))
# Note that np.fft.rfttfreq handles both even and odd signal lengths
frequency_grid[:, :] = 2 * np.pi * np.fft.rfftfreq(signal_length)[:, np.newaxis]
shifts = np.multiply(frequency_grid, shift_samples[np.newaxis, :], out=frequency_grid)
else:
shifts = omega[np.newaxis, :] * sample_shifts[:, np.newaxis]
sig_shift = np.fft.irfft(sig_f * np.exp(-1j * shifts), n=n, axis=axis)
return sig_shift
raise NotImplementedError("Axis != 0 is not implemented yet")

# Rotate the signal in the frequency domain
rotations = np.exp(-1j * shifts)
phase_shifted_signal = np.multiply(frequency_domain_signal, rotations, out=rotations)

# Inverse FFT to get the translated signal
shifted_signal = scipy.fft.irfft(phase_shifted_signal, n=signal_length, axis=axis, overwrite_x=True)
return shifted_signal


apply_fshift = apply_fshift_sam
apply_fshift = apply_frequency_shift


def apply_fshift_ibl(w, s, axis=0, ns=None):
Expand Down
Loading