From 5525a0c2c4ee463dd06f9d94565f04020c9a91c9 Mon Sep 17 00:00:00 2001 From: Aakash Ashok Naik <91958822+naik-aakash@users.noreply.github.com> Date: Mon, 18 Nov 2024 14:32:58 +0100 Subject: [PATCH] Fix incorrect comparison logic and update tests (#4181) * fix erroneous comp logic and simply complexcity based on suggestions * fix tests that correctly evaluate the checks * pre-commit auto-fixes * seperate has_good_quality_check_occupied_bands tests * address review comment * pre-commit auto-fixes * address review comment > change numeric values to scientific notation * pre-commit auto-fixes * add static method to get sub_array (helps in testing) * update tests * remove duplicate accidental test lines * remove duplicate assert * update test as per review suggestion * address review comments * address review comments2 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/pymatgen/io/lobster/outputs.py | 34 +++------ tests/io/lobster/test_outputs.py | 117 +++++++++++++++++++++++------ 2 files changed, 107 insertions(+), 44 deletions(-) diff --git a/src/pymatgen/io/lobster/outputs.py b/src/pymatgen/io/lobster/outputs.py index bc5f8e17800..300e2f68c90 100644 --- a/src/pymatgen/io/lobster/outputs.py +++ b/src/pymatgen/io/lobster/outputs.py @@ -1710,29 +1710,17 @@ def has_good_quality_check_occupied_bands( Returns: bool: True if the quality of the projection is good. """ - for matrix in self.band_overlaps_dict[Spin.up]["matrices"]: - for iband1, band1 in enumerate(matrix): - for iband2, band2 in enumerate(band1): - if iband1 < number_occ_bands_spin_up and iband2 < number_occ_bands_spin_up: - if iband1 == iband2: - if abs(band2 - 1.0).all() > limit_deviation: - return False - elif band2.all() > limit_deviation: - return False - - if spin_polarized: - for matrix in self.band_overlaps_dict[Spin.down]["matrices"]: - for iband1, band1 in enumerate(matrix): - for iband2, band2 in enumerate(band1): - if number_occ_bands_spin_down is None: - raise ValueError("number_occ_bands_spin_down has to be specified") - - if iband1 < number_occ_bands_spin_down and iband2 < number_occ_bands_spin_down: - if iband1 == iband2: - if abs(band2 - 1.0).all() > limit_deviation: - return False - elif band2.all() > limit_deviation: - return False + if spin_polarized and number_occ_bands_spin_down is None: + raise ValueError("number_occ_bands_spin_down has to be specified") + + for spin in (Spin.up, Spin.down) if spin_polarized else (Spin.up,): + num_occ_bands = number_occ_bands_spin_up if spin is Spin.up else number_occ_bands_spin_down + + for overlap_matrix in self.band_overlaps_dict[spin]["matrices"]: + sub_array = np.asarray(overlap_matrix)[:num_occ_bands, :num_occ_bands] + + if not np.allclose(sub_array, np.identity(num_occ_bands), atol=limit_deviation, rtol=0): + return False return True diff --git a/tests/io/lobster/test_outputs.py b/tests/io/lobster/test_outputs.py index 30ea62e687f..4dac8c4a01b 100644 --- a/tests/io/lobster/test_outputs.py +++ b/tests/io/lobster/test_outputs.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import json import os from unittest import TestCase @@ -1481,7 +1482,7 @@ def test_get_bandstructure(self): class TestBandoverlaps(TestCase): def setUp(self): - # test spin-polarized calc and non spinpolarized calc + # test spin-polarized calc and non spin-polarized calc self.band_overlaps1 = Bandoverlaps(f"{TEST_DIR}/bandOverlaps.lobster.1") self.band_overlaps2 = Bandoverlaps(f"{TEST_DIR}/bandOverlaps.lobster.2") @@ -1515,9 +1516,18 @@ def test_attributes(self): assert self.band_overlaps2.max_deviation[-1] == approx(1.48451e-05) assert self.band_overlaps2_new.max_deviation[-1] == approx(0.45154) - def test_has_good_quality(self): + def test_has_good_quality_maxDeviation(self): assert not self.band_overlaps1.has_good_quality_maxDeviation(limit_maxDeviation=0.1) assert not self.band_overlaps1_new.has_good_quality_maxDeviation(limit_maxDeviation=0.1) + + assert self.band_overlaps1.has_good_quality_maxDeviation(limit_maxDeviation=100) + assert self.band_overlaps1_new.has_good_quality_maxDeviation(limit_maxDeviation=100) + assert self.band_overlaps2.has_good_quality_maxDeviation() + assert not self.band_overlaps2_new.has_good_quality_maxDeviation() + assert not self.band_overlaps2.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001) + assert not self.band_overlaps2_new.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001) + + def test_has_good_quality_check_occupied_bands(self): assert not self.band_overlaps1.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=9, number_occ_bands_spin_down=5, @@ -1545,65 +1555,58 @@ def test_has_good_quality(self): assert not self.band_overlaps1.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=1, number_occ_bands_spin_down=1, - limit_deviation=0.000001, + limit_deviation=1e-6, spin_polarized=True, ) assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=1, number_occ_bands_spin_down=1, - limit_deviation=0.000001, + limit_deviation=1e-6, spin_polarized=True, ) assert not self.band_overlaps1.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=1, number_occ_bands_spin_down=0, - limit_deviation=0.000001, + limit_deviation=1e-6, spin_polarized=True, ) assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=1, number_occ_bands_spin_down=0, - limit_deviation=0.000001, + limit_deviation=1e-6, spin_polarized=True, ) assert not self.band_overlaps1.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=0, number_occ_bands_spin_down=1, - limit_deviation=0.000001, + limit_deviation=1e-6, spin_polarized=True, ) assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=0, number_occ_bands_spin_down=1, - limit_deviation=0.000001, + limit_deviation=1e-6, spin_polarized=True, ) assert not self.band_overlaps1.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=4, number_occ_bands_spin_down=4, - limit_deviation=0.001, + limit_deviation=1e-3, spin_polarized=True, ) assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=4, number_occ_bands_spin_down=4, - limit_deviation=0.001, + limit_deviation=1e-3, spin_polarized=True, ) - - assert self.band_overlaps1.has_good_quality_maxDeviation(limit_maxDeviation=100) - assert self.band_overlaps1_new.has_good_quality_maxDeviation(limit_maxDeviation=100) - assert self.band_overlaps2.has_good_quality_maxDeviation() - assert not self.band_overlaps2_new.has_good_quality_maxDeviation() - assert not self.band_overlaps2.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001) - assert not self.band_overlaps2_new.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001) assert not self.band_overlaps2.has_good_quality_check_occupied_bands( - number_occ_bands_spin_up=10, limit_deviation=0.0000001 + number_occ_bands_spin_up=10, limit_deviation=1e-7 ) assert not self.band_overlaps2_new.has_good_quality_check_occupied_bands( - number_occ_bands_spin_up=10, limit_deviation=0.0000001 + number_occ_bands_spin_up=10, limit_deviation=1e-7 ) - assert not self.band_overlaps2.has_good_quality_check_occupied_bands( + assert self.band_overlaps2.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=1, limit_deviation=0.1 ) @@ -1614,7 +1617,7 @@ def test_has_good_quality(self): number_occ_bands_spin_up=1, limit_deviation=1e-8 ) assert self.band_overlaps2.has_good_quality_check_occupied_bands(number_occ_bands_spin_up=10, limit_deviation=1) - assert not self.band_overlaps2_new.has_good_quality_check_occupied_bands( + assert self.band_overlaps2_new.has_good_quality_check_occupied_bands( number_occ_bands_spin_up=2, limit_deviation=0.1 ) assert self.band_overlaps2.has_good_quality_check_occupied_bands(number_occ_bands_spin_up=1, limit_deviation=1) @@ -1622,6 +1625,78 @@ def test_has_good_quality(self): number_occ_bands_spin_up=1, limit_deviation=2 ) + def test_has_good_quality_check_occupied_bands_patched(self): + """Test with patched data.""" + + limit_deviation = 0.1 + + rng = np.random.default_rng(42) # set seed for reproducibility + + band_overlaps = copy.deepcopy(self.band_overlaps1_new) + + number_occ_bands_spin_up_all = list(range(band_overlaps.band_overlaps_dict[Spin.up]["matrices"][0].shape[0])) + number_occ_bands_spin_down_all = list( + range(band_overlaps.band_overlaps_dict[Spin.down]["matrices"][0].shape[0]) + ) + + for actual_deviation in [0.05, 0.1, 0.2, 0.5, 1.0]: + for spin in (Spin.up, Spin.down): + for number_occ_bands_spin_up, number_occ_bands_spin_down in zip( + number_occ_bands_spin_up_all, number_occ_bands_spin_down_all, strict=False + ): + for i_arr, array in enumerate(band_overlaps.band_overlaps_dict[spin]["matrices"]): + number_occ_bands = number_occ_bands_spin_up if spin is Spin.up else number_occ_bands_spin_down + + shape = array.shape + assert np.all(np.array(shape) >= number_occ_bands) + assert len(shape) == 2 + assert shape[0] == shape[1] + + # Generate a noisy background array + patch_array = rng.uniform(0, 10, shape) + + # Patch the top-left sub-array (the part that would be checked) + patch_array[:number_occ_bands, :number_occ_bands] = np.identity(number_occ_bands) + rng.uniform( + 0, actual_deviation, (number_occ_bands, number_occ_bands) + ) + + band_overlaps.band_overlaps_dict[spin]["matrices"][i_arr] = patch_array + + result = band_overlaps.has_good_quality_check_occupied_bands( + number_occ_bands_spin_up=number_occ_bands_spin_up, + number_occ_bands_spin_down=number_occ_bands_spin_down, + spin_polarized=True, + limit_deviation=limit_deviation, + ) + # Assert for expected results + if ( + actual_deviation == 0.05 + and number_occ_bands_spin_up <= 7 + and number_occ_bands_spin_down <= 7 + and spin is Spin.up + or actual_deviation == 0.05 + and spin is Spin.down + or actual_deviation == 0.1 + or actual_deviation in [0.2, 0.5, 1.0] + and number_occ_bands_spin_up == 0 + and number_occ_bands_spin_down == 0 + ): + assert result + else: + assert not result + + def test_exceptions(self): + with pytest.raises(ValueError, match="number_occ_bands_spin_down has to be specified"): + self.band_overlaps1.has_good_quality_check_occupied_bands( + number_occ_bands_spin_up=4, + spin_polarized=True, + ) + with pytest.raises(ValueError, match="number_occ_bands_spin_down has to be specified"): + self.band_overlaps1_new.has_good_quality_check_occupied_bands( + number_occ_bands_spin_up=4, + spin_polarized=True, + ) + def test_msonable(self): dict_data = self.band_overlaps2_new.as_dict() bandoverlaps_from_dict = Bandoverlaps.from_dict(dict_data)