
Merge pull request #3640 from alejoe91/support_numpy_2
Support numpy 2.0
alejoe91 authored Jan 24, 2025
2 parents ae19c2a + 6cc73f5 commit bed6308
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/full-test-with-codecov.yml
@@ -23,7 +23,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.10'
+          python-version: '3.12'
       - name: Get ephy_testing_data current head hash
         # the key depends on the last commit repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git
         id: vars
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -7,7 +7,7 @@ authors = [
 ]
 description = "Python toolkit for analysis, visualization, and comparison of spike sorting output"
 readme = "README.md"
-requires-python = ">=3.9,<3.13" # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo
+requires-python = ">=3.9,<3.13"
 classifiers = [
     "Programming Language :: Python :: 3 :: Only",
     "License :: OSI Approved :: MIT License",
@@ -20,11 +20,11 @@ classifiers = [


 dependencies = [
-    "numpy>=1.20, <2.0", # 1.20 np.ptp, 1.26 might be necessary for avoiding pickling errors when numpy >2.0
+    "numpy>=1.20",
     "threadpoolctl>=3.0.0",
     "tqdm",
     "zarr>=2.18,<3",
-    "neo>=0.13.0",
+    "neo>=0.14.0",
     "probeinterface>=0.2.23",
     "packaging",
 ]
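
Note on the pyproject.toml changes: the numpy requirement drops its "<2.0" cap (so both 1.x >= 1.20 and 2.x are allowed) and neo is bumped to 0.14.0, which, given the removed "wait for fix on neo" comment, presumably carries the numpy 2 compatibility fixes. A minimal sketch of a runtime check mirroring the new pins, using only the already-declared "packaging" dependency (illustrative only, not part of the package):

from importlib.metadata import version
from packaging.version import Version

# Mirror the relaxed requirements from pyproject.toml: numpy>=1.20, neo>=0.14.0
numpy_ok = Version(version("numpy")) >= Version("1.20")
neo_ok = Version(version("neo")) >= Version("0.14.0")
print(f"numpy in supported range: {numpy_ok}; neo in supported range: {neo_ok}")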
10 changes: 5 additions & 5 deletions src/spikeinterface/core/waveform_tools.py
@@ -170,7 +170,7 @@ def allocate_waveforms_buffers(
         Dictionary to "construct" array in workers process (memmap file or sharemem)
     """

-    nsamples = nbefore + nafter
+    n_samples = nbefore + nafter

     dtype = np.dtype(dtype)
     if mode == "shared_memory":
@@ -187,11 +187,11 @@
             num_chans = recording.get_num_channels()
         else:
             num_chans = np.sum(sparsity_mask[unit_ind, :])
-        shape = (n_spikes, nsamples, num_chans)
+        shape = (int(n_spikes), int(n_samples), int(num_chans))

         if mode == "memmap":
             filename = str(folder / f"waveforms_{unit_id}.npy")
-            arr = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape)
+            arr = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype.str, shape=shape)
             waveforms_by_units[unit_id] = arr
             arrays_info[unit_id] = filename
         elif mode == "shared_memory":
@@ -476,7 +476,7 @@ def extract_waveforms_to_single_buffer(
         Optionally return in case of shared_memory if copy=False.
         Dictionary to "construct" array in workers process (memmap file or sharemem info)
     """
-    nsamples = nbefore + nafter
+    n_samples = nbefore + nafter

     dtype = np.dtype(dtype)
     if mode == "shared_memory":
@@ -489,7 +489,7 @@
         num_chans = recording.get_num_channels()
     else:
         num_chans = int(max(np.sum(sparsity_mask, axis=1)))  # This is a numpy scalar, so we cast to int
-    shape = (num_spikes, nsamples, num_chans)
+    shape = (int(num_spikes), int(n_samples), int(num_chans))

     if mode == "memmap":
         all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape)
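
Note on the waveform_tools.py changes: the waveform buffer shape is now built from plain Python ints and the memmap dtype is passed as dtype.str. Under numpy 2.x, reductions such as np.sum return numpy scalar types (e.g. np.int64) rather than Python ints, so without the casts the shape tuple carries numpy scalars; the explicit int(...) and dtype.str presumably keep the shape and dtype description in plain builtin types when the buffer info is handed to worker processes. A small standalone sketch of the pattern (the array sizes and file name are made up for illustration):

import numpy as np

# With numpy 2.x, np.sum returns a numpy scalar (np.int64), not a Python int.
sparsity_mask = np.array([[True, False, True, True]])
num_chans = np.sum(sparsity_mask[0, :])        # np.int64
n_samples = 30 + 30                            # nbefore + nafter
shape = (100, int(n_samples), int(num_chans))  # cast to plain ints, as in the diff

# open_memmap accepts a dtype object or its string form; dtype.str ("<f4" here)
# spells the dtype out explicitly for the .npy header.
dtype = np.dtype("float32")
waveforms = np.lib.format.open_memmap(
    "waveforms_demo.npy", mode="w+", dtype=dtype.str, shape=shape
)
waveforms[:] = 0.0
print(waveforms.shape, waveforms.dtype)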
