diff --git a/src/AMDGPUExt/update_halo.jl b/src/AMDGPUExt/update_halo.jl index b06f861..3870e00 100644 --- a/src/AMDGPUExt/update_halo.jl +++ b/src/AMDGPUExt/update_halo.jl @@ -138,7 +138,7 @@ let rocstreams = Array{AMDGPU.HIPStream}(undef, NNEIGHBORS_PER_DIM, 0) - wait_iwrite(n::Integer, A::ROCField{T}, i::Integer) where T <: GGNumber = AMDGPU.synchronize(rocstreams[n,i]); + wait_iwrite(n::Integer, A::ROCField{T}, i::Integer) where T <: GGNumber = AMDGPU.synchronize(rocstreams[n,i]; blocking=true); function allocate_rocstreams_iwrite(fields::GGField...) if length(fields) > size(rocstreams,2) # Note: for simplicity, we create a stream for every field even if it is not a ROCField @@ -169,7 +169,7 @@ let rocstreams = Array{AMDGPU.HIPStream}(undef, NNEIGHBORS_PER_DIM, 0) - wait_iread(n::Integer, A::ROCField{T}, i::Integer) where T <: GGNumber = AMDGPU.synchronize(rocstreams[n,i]); + wait_iread(n::Integer, A::ROCField{T}, i::Integer) where T <: GGNumber = AMDGPU.synchronize(rocstreams[n,i]; blocking=true); function allocate_rocstreams_iread(fields::GGField...) if length(fields) > size(rocstreams,2) # Note: for simplicity, we create a stream for every field even if it is not a ROCField diff --git a/src/CUDAExt/update_halo.jl b/src/CUDAExt/update_halo.jl index 27bdcf2..98bc054 100644 --- a/src/CUDAExt/update_halo.jl +++ b/src/CUDAExt/update_halo.jl @@ -150,7 +150,7 @@ let custreams = Array{CuStream}(undef, NNEIGHBORS_PER_DIM, 0) - wait_iwrite(n::Integer, A::CuField{T}, i::Integer) where T <: GGNumber = CUDA.synchronize(custreams[n,i]); + wait_iwrite(n::Integer, A::CuField{T}, i::Integer) where T <: GGNumber = CUDA.synchronize(custreams[n,i]; blocking=true); function allocate_custreams_iwrite(fields::GGField...) 
if length(fields) > size(custreams,2) # Note: for simplicity, we create a stream for every field even if it is not a CuField @@ -179,7 +179,7 @@ let custreams = Array{CuStream}(undef, NNEIGHBORS_PER_DIM, 0) - wait_iread(n::Integer, A::CuField{T}, i::Integer) where T <: GGNumber = CUDA.synchronize(custreams[n,i]); + wait_iread(n::Integer, A::CuField{T}, i::Integer) where T <: GGNumber = CUDA.synchronize(custreams[n,i]; blocking=true); function allocate_custreams_iread(fields::GGField...) if length(fields) > size(custreams,2) # Note: for simplicity, we create a stream for every field even if it is not a CuField diff --git a/src/update_halo.jl b/src/update_halo.jl index 4728488..e917506 100644 --- a/src/update_halo.jl +++ b/src/update_halo.jl @@ -26,22 +26,22 @@ Update the halo of the given GPU/CPU-array(s). shell> export IGG_ROCMAWARE_MPI=1 ``` """ -function update_halo!(A::Union{GGArray, GGField, GGFieldConvertible}...) +function update_halo!(A::Union{GGArray, GGField, GGFieldConvertible}...; dims=(NDIMS_MPI,(1:NDIMS_MPI-1)...)) check_initialized(); fields = wrap_field.(A); check_fields(fields...); - _update_halo!(fields...); # Assignment of A to fields in the internal function _update_halo!() as vararg A can consist of multiple fields; A will be used for a single field in the following (The args of update_halo! must however be "A..." for maximal simplicity and elegance for the user). + _update_halo!(fields...; dims=dims); # Assignment of A to fields in the internal function _update_halo!() as vararg A can consist of multiple fields; A will be used for a single field in the following (The args of update_halo! must however be "A..." for maximal simplicity and elegance for the user). return nothing end # -function _update_halo!(fields::GGField...) 
+function _update_halo!(fields::GGField...; dims=(NDIMS_MPI,(1:NDIMS_MPI-1)...)) if (!cuda_enabled() && !amdgpu_enabled() && !all_arrays(fields...)) error("not all arrays are CPU arrays, but no GPU extension is loaded.") end #NOTE: in the following, it is only required to check for `cuda_enabled()`/`amdgpu_enabled()` when the context does not imply `any_cuarray(fields...)` or `is_cuarray(A)` or the corresponding for AMDGPU. # NOTE: the case where only one of the two extensions are loaded, but an array dad would be for the other extension is passed is very unlikely and therefore not explicitly checked here (but could be added later). allocate_bufs(fields...); if any_array(fields...) allocate_tasks(fields...); end if any_cuarray(fields...) allocate_custreams(fields...); end if any_rocarray(fields...) allocate_rocstreams(fields...); end - for dim = 1:NDIMS_MPI # NOTE: this works for 1D-3D (e.g. if nx>1, ny>1 and nz=1, then for d=3, there will be no neighbors, i.e. nothing will be done as desired...). + for dim in dims # NOTE: this works for 1D-3D (e.g. if nx>1, ny>1 and nz=1, then for d=3, there will be no neighbors, i.e. nothing will be done as desired...). for ns = 1:NNEIGHBORS_PER_DIM, i = 1:length(fields) if has_neighbor(ns, dim) iwrite_sendbufs!(ns, dim, fields[i], i); end end