From 5f549a5d485df3ff85b47941be96af06a1e7d133 Mon Sep 17 00:00:00 2001 From: cortner Date: Tue, 30 Jul 2024 18:01:30 -0700 Subject: [PATCH 1/2] allow elements passed as tuple --- src/defaults.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/defaults.jl b/src/defaults.jl index c43b2da..f9550c5 100644 --- a/src/defaults.jl +++ b/src/defaults.jl @@ -141,8 +141,12 @@ function _get_r0(kwargs, z1, z2) error("Unable to determine r0($z1, $z2) from the arguments provided.") end +function _get_elements(kwargs) + return [ kwargs[:elements]... ] +end + function _get_all_r0(kwargs) - elements = kwargs[:elements] + elements = _get_elements(kwargs) r0 = Dict( [ (s1, s2) => _get_r0(kwargs, s1, s2) for s1 in elements, s2 in elements]... ) end @@ -164,14 +168,14 @@ function _get_all_rcut(kwargs; _rcut = kwargs[:rcut]) if _rcut isa Number return _rcut end - elements = kwargs[:elements] + elements = _get_elements(kwargs) rcut = Dict( [ (s1, s2) => _get_rcut(kwargs, s1, s2; _rcut = _rcut) for s1 in elements, s2 in elements]... ) return rcut end function _transform(kwargs; transform = kwargs[:transform]) - elements = kwargs[:elements] + elements = _get_elements(kwargs) if transform isa Tuple if transform[1] == :agnesi @@ -256,7 +260,7 @@ end function _pair_basis(kwargs) rbasis = kwargs[:pair_basis] - elements = kwargs[:elements] + elements = _get_elements(kwargs) #elements has to be sorted becuase PolyPairBasis (see end of function) assumes sorted. if kwargs[:variable_cutoffs] elements = [chemical_symbol(z) for z in JuLIP.Potentials.ZList(elements, static=true).list] @@ -313,7 +317,7 @@ end function mb_ace_basis(kwargs) - elements = kwargs[:elements] + elements = _get_elements(kwargs) cor_order = _get_order(kwargs) Deg, maxdeg, maxn = _get_degrees(kwargs) rbasis = _radial_basis(kwargs) From 2f3335f9fb26c1e2120103b9253b5a80b4218a8c Mon Sep 17 00:00:00 2001 From: cortner Date: Tue, 30 Jul 2024 18:37:16 -0700 Subject: [PATCH 2/2] add a test for #18 --- test/runtests.jl | 1 + test/test_bugs.jl | 14 ++++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 test/test_bugs.jl diff --git a/test/runtests.jl b/test/runtests.jl index ccb282f..c888edc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,5 @@ using Test @testset "Purify-single" begin include("test_purify.jl"); end @testset "Purify-multi" begin include("test_purify_multi.jl"); end @testset "acemodel" begin include("test_acemodel.jl"); end + @testset "Weird Bugs" begin include("test_bugs.jl"); end end diff --git a/test/test_bugs.jl b/test/test_bugs.jl new file mode 100644 index 0000000..d738e5d --- /dev/null +++ b/test/test_bugs.jl @@ -0,0 +1,14 @@ + +using ACE1x, Test + +model1 = acemodel(elements = [:Si,], + order = 3, + totaldegree = 6, + rcut = 5.5, ) + +model2 = acemodel(elements = (:Si,), + order = 3, + totaldegree = 6, + rcut = 5.5, ) + +@test model1.basis == model2.basis \ No newline at end of file