Skip to content

Commit

Permalink
Merge pull request #1625 from alejoe91/final-2.0-fixes
Browse files Browse the repository at this point in the history
Fix failing tests due to Numpy 2.0
  • Loading branch information
samuelgarcia authored Jan 17, 2025
2 parents dbe2e95 + 70d040f commit c034591
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/io-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ['3.9', '3.12']
python-version: ['3.9', '3.13']
defaults:
# by default run in bash mode (required for conda usage)
run:
Expand Down
3 changes: 0 additions & 3 deletions environment_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,3 @@ channels:
dependencies:
- datalad
- pip
# temporary have this here for IO testing while we decide how to deal with
# external packages not 2.0 ready
- numpy=1.26.4
2 changes: 1 addition & 1 deletion neo/io/klustakwikio.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _load_spike_times(self, fetfilename):
names.append("spike_time")

# Load into recarray
data = np.recfromtxt(fetfilename, names=names, skip_header=1, delimiter=" ")
data = np.genfromtxt(fetfilename, names=names, skip_header=1, delimiter=" ")

# get features
features = np.array([data[f"fet{n}"] for n in range(nbFeatures)])
Expand Down
7 changes: 4 additions & 3 deletions neo/rawio/blackrockrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ def __match_nsx_and_nev_segment_ids(self, nsx_nb):

# Show warning if spikes do not fit any segment (+- 1 sampling 'tick')
# Spike should belong to segment before
mask_outside = (ev_ids == i) & (data["timestamp"] < int(seg["timestamp"]) - nsx_offset - nsx_period)
mask_outside = (ev_ids == i) & (data["timestamp"] < int(seg["timestamp"]) - int(nsx_offset) - int(nsx_period))

if len(data[mask_outside]) > 0:
warnings.warn(f"Spikes outside any segment. Detected on segment #{i}")
Expand Down Expand Up @@ -1995,6 +1995,7 @@ def __get_nsx_param_variant_a(self, nsx_nb):
else:
units = "uV"


nsx_parameters = {
"nb_data_points": int(
(self.__get_file_size(filename) - bytes_in_headers)
Expand All @@ -2003,8 +2004,8 @@ def __get_nsx_param_variant_a(self, nsx_nb):
),
"labels": labels,
"units": np.array([units] * self.__nsx_basic_header[nsx_nb]["channel_count"]),
"min_analog_val": -1 * np.array(dig_factor),
"max_analog_val": np.array(dig_factor),
"min_analog_val": -1 * np.array(dig_factor, dtype="float"),
"max_analog_val": np.array(dig_factor, dtype="float"),
"min_digital_val": np.array([-1000] * self.__nsx_basic_header[nsx_nb]["channel_count"]),
"max_digital_val": np.array([1000] * self.__nsx_basic_header[nsx_nb]["channel_count"]),
"timestamp_resolution": 30000,
Expand Down
18 changes: 9 additions & 9 deletions neo/test/iotest/test_asciisignalio.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ def test_skiprows(self):
self.assertEqual(signal.units, pq.V)
assert_array_equal(signal.times, [0.0, 1.0, 2.0, 3.0] * pq.s)
assert_array_equal(signal.times.magnitude, [0.0, 1.0, 2.0, 3.0])
assert_array_equal(signal[0].magnitude, -64.8)
assert_array_equal(signal[1].magnitude, -64.6)
assert_array_equal(signal[2].magnitude, -64.3)
assert_array_equal(signal[3].magnitude, -66)
assert_array_almost_equal(signal[0].magnitude, -64.8, decimal=5)
assert_array_almost_equal(signal[1].magnitude, -64.6, decimal=5)
assert_array_almost_equal(signal[2].magnitude, -64.3, decimal=5)
assert_array_almost_equal(signal[3].magnitude, -66, decimal=5)
assert_array_almost_equal(np.asarray(signal).flatten(), np.array([-64.8, -64.6, -64.3, -66]), decimal=5)

os.remove(filename)
Expand All @@ -195,11 +195,11 @@ def test_usecols(self):
self.assertEqual(signal.units, pq.V)
assert_array_equal(signal.times, [0.0, 1.0, 2.0, 3.0, 4.0] * pq.s)
assert_array_equal(signal.times.magnitude, [0.0, 1.0, 2.0, 3.0, 4.0])
assert_array_equal(signal[0].magnitude, 0.5)
assert_array_equal(signal[1].magnitude, 0.6)
assert_array_equal(signal[2].magnitude, 0.7)
assert_array_equal(signal[3].magnitude, 0.8)
assert_array_equal(signal[4].magnitude, 1.4)
assert_array_almost_equal(signal[0].magnitude, 0.5, decimal=5)
assert_array_almost_equal(signal[1].magnitude, 0.6, decimal=5)
assert_array_almost_equal(signal[2].magnitude, 0.7, decimal=5)
assert_array_almost_equal(signal[3].magnitude, 0.8, decimal=5)
assert_array_almost_equal(signal[4].magnitude, 1.4, decimal=5)
assert_array_almost_equal(np.asarray(signal).flatten(), np.array([0.5, 0.6, 0.7, 0.8, 1.4]), decimal=5)

os.remove(filename)
Expand Down
12 changes: 6 additions & 6 deletions neo/test/iotest/test_axographio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from neo.test.iotest.common_io_test import BaseTestIO

import numpy as np
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_almost_equal
import quantities as pq
from neo.test.rawiotest.test_axographrawio import TestAxographRawIO

Expand All @@ -35,9 +35,9 @@ def test_version_1(self):
target = np.array([[-5.5078130], [-3.1171880], [+1.6640626], [+1.6640626], [+4.0546880]], dtype=np.float32)
assert_equal(arr, target)

assert_equal(sig.t_start, 0.0005000000237487257 * pq.s)
assert_almost_equal(sig.t_start, 0.0005000000237487257 * pq.s, decimal=9)

assert_equal(sig.sampling_period, 0.0005000010132789612 * pq.s)
assert_almost_equal(sig.sampling_period, 0.0005000010132789612 * pq.s, decimal=9)

def test_version_2(self):
"""Test reading a version 2 AxoGraph file"""
Expand Down Expand Up @@ -87,9 +87,9 @@ def test_version_2(self):
target = np.array([[0.3125], [9.6875], [9.6875], [9.6875], [9.3750]], dtype=np.float32)
assert_equal(arr, target)

assert_equal(sig.t_start, 0.00009999999747378752 * pq.s)
assert_almost_equal(sig.t_start, 0.00009999999747378752 * pq.s, decimal=9)

assert_equal(sig.sampling_period, 0.00009999999747378750 * pq.s)
assert_almost_equal(sig.sampling_period, 0.00009999999747378750 * pq.s, decimal=9)

def test_version_5(self):
"""Test reading a version 5 AxoGraph file"""
Expand Down Expand Up @@ -169,7 +169,7 @@ def test_file_written_by_axographio_package_without_linearsequence(self):

assert_equal(sig.t_start, 0 * pq.s)

assert_equal(sig.sampling_period, 0.009999999999999787 * pq.s)
assert_almost_equal(sig.sampling_period, 0.009999999999999787 * pq.s, decimal=9)

def test_file_with_corrupt_comment(self):
"""Test reading a file with a corrupt comment"""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ docs = [
"nixio",
"pynwb",
"igor2",
"numpy<2.0" # https://github.com/NeuralEnsemble/python-neo/pull/1612
"numpy>=2.0"
]

dev = [
Expand Down

0 comments on commit c034591

Please sign in to comment.