From 704d8f0e7f6377b61dd54b9755c1953e591212a8 Mon Sep 17 00:00:00 2001 From: Hoppe Date: Tue, 3 Dec 2024 16:17:29 +0100 Subject: [PATCH 1/6] created branch for QR in case split=0 and non-tall-skinny matrices --- heat/core/linalg/qr.py | 221 +++++++++++++++++++++-------------------- 1 file changed, 113 insertions(+), 108 deletions(-) diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index 8fb26b204..c4a517122 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -181,121 +181,126 @@ def qr( return QR(Q, R) if A.split == A.ndim - 2: - # implementation of TS-QR for split = 0 # check that data distribution is reasonable for TS-QR (i.e. tall-skinny matrix with also tall-skinny local chunks of data) - if A.lshape_map[:, -2].max().item() < A.shape[-1]: - raise ValueError( - "A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub." - ) - - current_procs = [i for i in range(A.comm.size)] - current_comm = A.comm - local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) - Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode) - R_loc = R_loc.contiguous() # required for all the communication ops lateron - if mode == "reduced": - leave_comm = current_comm.Split(current_comm.rank, A.comm.rank) - - level = 1 - while len(current_procs) > 1: - if A.comm.rank in current_procs and local_comm.size > 1: - # create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes - shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0) - if local_comm.rank == 0: - gathered_R_loc = torch.zeros( - (*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]), - device=R_loc.device, - dtype=R_loc.dtype, - ) - counts = list(shapes_R_loc) - displs = torch.cumsum( - torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0 - ).tolist()[:-1] - else: - gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) - counts = None - displs = None - # gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes - local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2) - # perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc - if local_comm.rank == 0: - previous_shape = R_loc.shape - Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode) - R_loc = R_loc.contiguous() - else: - Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) - if mode == "reduced": + smallest_number_of_local_rows = A.lshape_map[:, -2].min().item() + if smallest_number_of_local_rows < A.shape[-1]: + # raise ValueError( + # "A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub." 
+ # ) + return 0 + else: + # in this case the input is tall-skinny and we apply the TS-QR algorithm + # it follows the implementation of TS-QR for split = 0 + current_procs = [i for i in range(A.comm.size)] + current_comm = A.comm + local_comm = current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) + Q_loc, R_loc = torch.linalg.qr(A.larray, mode=mode) + R_loc = R_loc.contiguous() # required for all the communication ops lateron + if mode == "reduced": + leave_comm = current_comm.Split(current_comm.rank, A.comm.rank) + + level = 1 + while len(current_procs) > 1: + if A.comm.rank in current_procs and local_comm.size > 1: + # create array to collect the R_loc's from all processes of the process group of at most n_procs_to_merge processes + shapes_R_loc = local_comm.gather(R_loc.shape[-2], root=0) if local_comm.rank == 0: - Q_buf = Q_buf.contiguous() - scattered_Q_buf = torch.empty( - R_loc.shape if local_comm.rank != 0 else previous_shape, - device=R_loc.device, - dtype=R_loc.dtype, - ) - # scatter the Q_buf to all processes of the process group - local_comm.Scatterv((Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2) - del gathered_R_loc, Q_buf - - # for each process in the current processes, broadcast the scattered_Q_buf of this process - # to all leaves (i.e. all original processes that merge to the current process) - if mode == "reduced" and leave_comm.size > 1: + gathered_R_loc = torch.zeros( + (*R_loc.shape[:-2], sum(shapes_R_loc), R_loc.shape[-1]), + device=R_loc.device, + dtype=R_loc.dtype, + ) + counts = list(shapes_R_loc) + displs = torch.cumsum( + torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0 + ).tolist()[:-1] + else: + gathered_R_loc = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) + counts = None + displs = None + # gather the R_loc's from all processes of the process group of at most n_procs_to_merge processes + local_comm.Gatherv(R_loc, (gathered_R_loc, counts, displs), root=0, axis=-2) + # perform QR decomposition on the concatenated, gathered R_loc's to obtain new R_loc + if local_comm.rank == 0: + previous_shape = R_loc.shape + Q_buf, R_loc = torch.linalg.qr(gathered_R_loc, mode=mode) + R_loc = R_loc.contiguous() + else: + Q_buf = torch.empty(0, device=R_loc.device, dtype=R_loc.dtype) + if mode == "reduced": + if local_comm.rank == 0: + Q_buf = Q_buf.contiguous() + scattered_Q_buf = torch.empty( + R_loc.shape if local_comm.rank != 0 else previous_shape, + device=R_loc.device, + dtype=R_loc.dtype, + ) + # scatter the Q_buf to all processes of the process group + local_comm.Scatterv( + (Q_buf, counts, displs), scattered_Q_buf, root=0, axis=-2 + ) + del gathered_R_loc, Q_buf + + # for each process in the current processes, broadcast the scattered_Q_buf of this process + # to all leaves (i.e. 
all original processes that merge to the current process) + if mode == "reduced" and leave_comm.size > 1: + try: + scattered_Q_buf_shape = scattered_Q_buf.shape + except UnboundLocalError: + scattered_Q_buf_shape = None + scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0) + if scattered_Q_buf_shape is not None: + # this is needed to ensure that only those Q_loc get updates that are actually part of the current process group + if leave_comm.rank != 0: + scattered_Q_buf = torch.empty( + scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype + ) + leave_comm.Bcast(scattered_Q_buf, root=0) + # update the local Q_loc by multiplying it with the scattered_Q_buf try: - scattered_Q_buf_shape = scattered_Q_buf.shape + Q_loc = Q_loc @ scattered_Q_buf + del scattered_Q_buf except UnboundLocalError: - scattered_Q_buf_shape = None - scattered_Q_buf_shape = leave_comm.bcast(scattered_Q_buf_shape, root=0) - if scattered_Q_buf_shape is not None: - # this is needed to ensure that only those Q_loc get updates that are actually part of the current process group - if leave_comm.rank != 0: - scattered_Q_buf = torch.empty( - scattered_Q_buf_shape, device=Q_loc.device, dtype=Q_loc.dtype + pass + + # update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering + current_procs = [ + current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0 + ] + if len(current_procs) > 1: + new_group = A.comm.group.Incl(current_procs) + current_comm = A.comm.Create_group(new_group) + if A.comm.rank in current_procs: + local_comm = communication.MPICommunication( + current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) ) - leave_comm.Bcast(scattered_Q_buf, root=0) - # update the local Q_loc by multiplying it with the scattered_Q_buf - try: - Q_loc = Q_loc @ scattered_Q_buf - del scattered_Q_buf - except UnboundLocalError: - pass - - # update: determine processes to be active at next "merging" level, create new communicator and split it into groups for gathering - current_procs = [ - current_procs[i] for i in range(len(current_procs)) if i % procs_to_merge == 0 - ] - if len(current_procs) > 1: - new_group = A.comm.group.Incl(current_procs) - current_comm = A.comm.Create_group(new_group) - if A.comm.rank in current_procs: - local_comm = communication.MPICommunication( - current_comm.Split(current_comm.rank // procs_to_merge, A.comm.rank) - ) - if mode == "reduced": - leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank) - level += 1 - # broadcast the final R_loc to all processes - R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1]) - if A.comm.rank != 0: - R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device) - A.comm.Bcast(R_loc, root=0) - R = DNDarray( - R_loc, - gshape=R_gshape, - dtype=A.dtype, - split=None, - device=A.device, - comm=A.comm, - balanced=True, - ) - if mode == "r": - Q = None - else: - Q = DNDarray( - Q_loc, - gshape=A.shape, + if mode == "reduced": + leave_comm = A.comm.Split(A.comm.rank // procs_to_merge**level, A.comm.rank) + level += 1 + # broadcast the final R_loc to all processes + R_gshape = (*A.shape[:-2], A.shape[-1], A.shape[-1]) + if A.comm.rank != 0: + R_loc = torch.empty(R_gshape, dtype=R_loc.dtype, device=R_loc.device) + A.comm.Bcast(R_loc, root=0) + R = DNDarray( + R_loc, + gshape=R_gshape, dtype=A.dtype, - split=A.split, + split=None, device=A.device, comm=A.comm, balanced=True, ) - return QR(Q, R) + if mode == "r": + Q = None + 
else: + Q = DNDarray( + Q_loc, + gshape=A.shape, + dtype=A.dtype, + split=A.split, + device=A.device, + comm=A.comm, + balanced=True, + ) + return QR(Q, R) From 7580552c9f01794b2cb9367294120795d8091c5e Mon Sep 17 00:00:00 2001 From: Hoppe Date: Wed, 4 Dec 2024 15:39:42 +0100 Subject: [PATCH 2/6] ... --- dev.py | 12 ++++++++++++ heat/core/linalg/qr.py | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 dev.py diff --git a/dev.py b/dev.py new file mode 100644 index 000000000..c09574a37 --- /dev/null +++ b/dev.py @@ -0,0 +1,12 @@ +""" +remove before merge +""" + +import heat as ht + +x = ht.random.rand(10, 10, split=0) + +print(x) + +q = ht.linalg.qr(x) +print(q) diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index c4a517122..a5acfc7b1 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -187,7 +187,11 @@ def qr( # raise ValueError( # "A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub." # ) - return 0 + column_idx = torch.arange( + 0, A.shape[-1], smallest_number_of_local_rows, dtype=torch.int64 + ) + return column_idx + else: # in this case the input is tall-skinny and we apply the TS-QR algorithm # it follows the implementation of TS-QR for split = 0 From 6430afd4c419fe0efffad2d17a48d2dd0ce7e874 Mon Sep 17 00:00:00 2001 From: Hoppe Date: Mon, 9 Dec 2024 17:37:30 +0100 Subject: [PATCH 3/6] QR for split=0 and non tall-skinny data --- dev.py | 12 ------ heat/core/linalg/qr.py | 69 +++++++++++++++++++++++++------ heat/core/linalg/tests/test_qr.py | 14 ++++--- 3 files changed, 65 insertions(+), 30 deletions(-) delete mode 100644 dev.py diff --git a/dev.py b/dev.py deleted file mode 100644 index c09574a37..000000000 --- a/dev.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -remove before merge -""" - -import heat as ht - -x = ht.random.rand(10, 10, split=0) - -print(x) - -q = ht.linalg.qr(x) -print(q) diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index a5acfc7b1..8544c132f 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -7,6 +7,7 @@ from typing import Tuple from ..dndarray import DNDarray +from ..manipulations import concatenate from .. import factories from .. import communication from ..types import float32, float64 @@ -31,7 +32,6 @@ def qr( ---------- A : DNDarray of shape (M, N), of shape (...,M,N) in the batched case Array which will be decomposed. So far only arrays with datatype float32 or float64 are supported - For split=0 (-2, in the batched case), the matrix must be tall skinny, i.e. the local chunks of data must have at least as many rows as columns. mode : str, optional default "reduced" returns Q and R with dimensions (M, min(M,N)) and (min(M,N), N). Potential batch dimensions are not modified. "r" returns only R, with dimensions (min(M,N), N). @@ -46,13 +46,17 @@ def qr( - If ``A`` is distributed along the columns (A.split = 1), so will be ``Q`` and ``R``. - - If ``A`` is distributed along the rows (A.split = 0), ``Q`` too will have `split=0`, but ``R`` won't be distributed, i.e. `R. split = None` and a full copy of ``R`` will be stored on each process. + - If ``A`` is distributed along the rows (A.split = 0), ``Q`` too will have `split=0`. ``R`` won't be distributed, i.e. `R. 
split = None`, if ``A`` is tall-skinny, i.e., if + the largest local chunk of data of ``A`` has at least as many rows as columns. Otherwise, ``R`` will be distributed along the rows as well, i.e., `R.split = 0`. Note that the argument `calc_q` allowed in earlier Heat versions is no longer supported; `calc_q = False` is equivalent to `mode = "r"`. Unlike ``numpy.linalg.qr()``, `ht.linalg.qr` only supports ``mode="reduced"`` or ``mode="r"`` for the moment, since "complete" may result in heavy memory usage. Heats QR function is built on top of PyTorchs QR function, ``torch.linalg.qr()``, using LAPACK (CPU) and MAGMA (CUDA) on - the backend. For split=0 (-2, in the batched case), tall-skinny QR (TS-QR) is implemented, while for split=1 (-1, in the batched case) a block-wise version of stabilized Gram-Schmidt orthogonalization is used. + the backend. Both cases split=0 and split=1 build on a column-block-wise version of stabilized Gram-Schmidt orthogonalization. + For split=1 (-1, in the batched case), this is directly applied to the local arrays of the input array. + For split=0, a tall-skinny QR (TS-QR) is implemented for the case of tall-skinny matrices (i.e., the largest local chunk of data has at least as many rows as columns), + and extended to non tall-skinny matrices by applying a block-wise version of stabilized Gram-Schmidt orthogonalization. References ----------- @@ -181,16 +185,57 @@ def qr( return QR(Q, R) if A.split == A.ndim - 2: - # check that data distribution is reasonable for TS-QR (i.e. tall-skinny matrix with also tall-skinny local chunks of data) - smallest_number_of_local_rows = A.lshape_map[:, -2].min().item() - if smallest_number_of_local_rows < A.shape[-1]: - # raise ValueError( - # "A is split along the rows and the local chunks of data are rectangular with more rows than columns. \n Applying TS-QR in this situation is not reasonable w.r.t. runtime and memory consumption. \n We recomment to split A along the columns instead. \n In case this is not an option for you, please open an issue on GitHub." 
- # ) - column_idx = torch.arange( - 0, A.shape[-1], smallest_number_of_local_rows, dtype=torch.int64 + # check that data distribution is reasonable for TS-QR + # we regard a matrix with split = 0 as suitable for TS-QR is largest local chunk of data has at least as many rows as columns + biggest_number_of_local_rows = A.lshape_map[:, -2].max().item() + if biggest_number_of_local_rows < A.shape[-1]: + column_idx = torch.cumsum(A.lshape_map[:, -2], 0) + column_idx = column_idx[column_idx < A.shape[-1]] + column_idx = torch.cat( + ( + torch.tensor([0], device=column_idx.device), + column_idx, + torch.tensor([A.shape[-1]], device=column_idx.device), + ) ) - return column_idx + A_copy = A.copy() + R = A.copy() + # Block-wise Gram-Schmidt orthogonalization, applied to groups of columns + offset = 1 if A.shape[-1] <= A.shape[-2] else 2 + for k in range(len(column_idx) - offset): + # since we only consider a group of columns, TS QR is applied to a tall-skinny matrix + Qnew, Rnew = qr( + A_copy[..., :, column_idx[k] : column_idx[k + 1]], + mode="reduced", + procs_to_merge=procs_to_merge, + ) + + # usual update of the remaining columns + if R.comm.rank == k: + R.larray[ + ..., + : (column_idx[k + 1] - column_idx[k]), + column_idx[k] : column_idx[k + 1], + ] = Rnew.larray + if R.comm.rank > k: + R.larray[..., :, column_idx[k] : column_idx[k + 1]] *= 0 + if k < len(column_idx) - 2: + coeffs = ( + torch.transpose(Qnew.larray, -2, -1) + @ A_copy.larray[..., :, column_idx[k + 1] :] + ) + R.comm.Allreduce(communication.MPI.IN_PLACE, coeffs) + if R.comm.rank == k: + R.larray[..., :, column_idx[k + 1] :] = coeffs + A_copy.larray[..., :, column_idx[k + 1] :] -= Qnew.larray @ coeffs + if mode == "reduced": + Q = Qnew if k == 0 else concatenate((Q, Qnew), axis=-1) + if A.shape[-1] < A.shape[-2]: + R = R[..., : A.shape[-1], :].balance() + if mode == "reduced": + return QR(Q, R) + else: + return QR(None, R) else: # in this case the input is tall-skinny and we apply the TS-QR algorithm diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py index 08d130486..c149a97ce 100644 --- a/heat/core/linalg/tests/test_qr.py +++ b/heat/core/linalg/tests/test_qr.py @@ -76,7 +76,11 @@ def test_qr_split0(self): for procs_to_merge in [0, 2, 3]: for mode in ["reduced", "r"]: # split = 0 can be handeled only for tall skinny matrices s.t. 
the local chunks are at least square too - for shape in [(40 * ht.MPI_WORLD.size + 1, 40), (40 * ht.MPI_WORLD.size, 20)]: + for shape in [ + (20 * ht.MPI_WORLD.size + 1, 40 * ht.MPI_WORLD.size), + (20 * ht.MPI_WORLD.size, 20 * ht.MPI_WORLD.size), + (40 * ht.MPI_WORLD.size - 1, 20 * ht.MPI_WORLD.size), + ]: for dtype in [ht.float32, ht.float64]: dtypetol = 1e-3 if dtype == ht.float32 else 1e-6 mat = ht.random.randn(*shape, dtype=dtype, split=split) @@ -147,7 +151,9 @@ def test_batched_qr_split1(self): def test_batched_qr_split0(self): # one batch dimension, float32 data type, "split = 0" (second last dimension) - x = ht.random.randn(8, ht.MPI_WORLD.size * 10 + 3, 9, dtype=ht.float32, split=1) + x = ht.random.randn( + 8, ht.MPI_WORLD.size * 10 + 3, ht.MPI_WORLD.size * 10 - 1, dtype=ht.float32, split=1 + ) q, r = ht.linalg.qr(x) batched_id = ht.stack([ht.eye(q.shape[2], dtype=ht.float32) for _ in range(q.shape[0])]) @@ -178,7 +184,3 @@ def test_wronginputs(self): # test wrong dtype with self.assertRaises(TypeError): ht.linalg.qr(ht.zeros((10, 10), dtype=ht.int32)) - # test wrong shape for split=0 - if ht.MPI_WORLD.size > 1: - with self.assertRaises(ValueError): - ht.linalg.qr(ht.zeros((10, 10), split=0)) From 3a770701108f39b890cfc67fd715addc7f4cb31f Mon Sep 17 00:00:00 2001 From: Hoppe Date: Tue, 10 Dec 2024 10:18:21 +0100 Subject: [PATCH 4/6] debugging --- heat/core/linalg/tests/test_qr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py index c149a97ce..f5c9b25ec 100644 --- a/heat/core/linalg/tests/test_qr.py +++ b/heat/core/linalg/tests/test_qr.py @@ -150,6 +150,7 @@ def test_batched_qr_split1(self): self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6)) def test_batched_qr_split0(self): + ht.random.seed(424242) # one batch dimension, float32 data type, "split = 0" (second last dimension) x = ht.random.randn( 8, ht.MPI_WORLD.size * 10 + 3, ht.MPI_WORLD.size * 10 - 1, dtype=ht.float32, split=1 From 2c98fdaec2fcf758226720b9acfc200e28c70a0c Mon Sep 17 00:00:00 2001 From: Hoppe Date: Fri, 14 Feb 2025 10:40:11 +0100 Subject: [PATCH 5/6] typo corrected according to review --- heat/core/linalg/qr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index 8544c132f..7c62258f3 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -186,7 +186,7 @@ def qr( if A.split == A.ndim - 2: # check that data distribution is reasonable for TS-QR - # we regard a matrix with split = 0 as suitable for TS-QR is largest local chunk of data has at least as many rows as columns + # we regard a matrix with split = 0 as suitable for TS-QR if its largest local chunk of data has at least as many rows as columns biggest_number_of_local_rows = A.lshape_map[:, -2].max().item() if biggest_number_of_local_rows < A.shape[-1]: column_idx = torch.cumsum(A.lshape_map[:, -2], 0) From 3465c7ce2aa75bfa1e256ed2de40d2bcb5522339 Mon Sep 17 00:00:00 2001 From: Hoppe Date: Fri, 14 Feb 2025 11:00:08 +0100 Subject: [PATCH 6/6] added benchmark for qr with split 0 but non-tall skinny shape --- benchmarks/cb/linalg.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/benchmarks/cb/linalg.py b/benchmarks/cb/linalg.py index 3596d4916..9202ca0d5 100644 --- a/benchmarks/cb/linalg.py +++ b/benchmarks/cb/linalg.py @@ -19,6 +19,11 @@ def qr_split_0(a): qr = ht.linalg.qr(a) +@monitor() +def qr_split_0_square(a): + qr = ht.linalg.qr(a) + + @monitor() def qr_split_1(a): qr = 
ht.linalg.qr(a) @@ -57,6 +62,11 @@ def run_linalg_benchmarks(): qr_split_0(a_0) del a_0 + n = 2000 + a_0 = ht.random.random((n, n), split=0) + qr_split_0_square(a_0) + del a_0 + n = 2000 a_1 = ht.random.random((n, n), split=1) qr_split_1(a_1)
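
Background: algorithm sketches (illustrative, single-process)

PATCH 1/6 implements tall-skinny QR (TS-QR) for split=0: every process factors its local row block, the small R factors are gathered within groups of at most procs_to_merge processes and factored again, and the Q factor of that merge step is scattered back and multiplied into the local Q factors. The following single-process sketch shows one merge level using plain torch; the function name tsqr_one_level and the block count are illustrative and not part of Heat's API.

    import torch

    def tsqr_one_level(A: torch.Tensor, n_blocks: int):
        # Tall-skinny QR with a single merge level. Every row block must
        # itself be tall-skinny (at least as many rows as columns).
        m, n = A.shape
        blocks = torch.chunk(A, n_blocks, dim=0)
        assert all(b.shape[-2] >= n for b in blocks)
        # step 1: independent local QR per row block (one per MPI process in Heat)
        qr_locals = [torch.linalg.qr(b, mode="reduced") for b in blocks]
        # step 2: stack the small R factors and factor the stack once more
        # (the patch gathers them with Gatherv onto the group root instead)
        R_stack = torch.cat([r for _, r in qr_locals], dim=0)
        Q_merge, R = torch.linalg.qr(R_stack, mode="reduced")
        # step 3: split Q_merge into n-row pieces and fold each piece into
        # the corresponding local Q (Scatterv plus a local matmul in Heat)
        Q_pieces = torch.chunk(Q_merge, n_blocks, dim=0)
        Q = torch.cat([q @ p for (q, _), p in zip(qr_locals, Q_pieces)], dim=0)
        return Q, R

    A = torch.randn(64, 8, dtype=torch.float64)
    Q, R = tsqr_one_level(A, n_blocks=4)
    print(torch.allclose(Q @ R, A))                                    # True
    print(torch.allclose(Q.T @ Q, torch.eye(8, dtype=torch.float64)))  # True

The patch applies this merge step recursively: after each level only the group roots stay active, so the number of participating processes shrinks by a factor of procs_to_merge per level until a single R remains, which is finally broadcast to all processes.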
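
The Gatherv/Scatterv bookkeeping in the same patch derives counts from the per-process row counts of R_loc and displs as their exclusive prefix sum. With illustrative row counts for three processes:

    import torch

    shapes_R_loc = [8, 8, 7]  # rows of R_loc on each process (illustrative)
    counts = list(shapes_R_loc)
    displs = torch.cumsum(
        torch.tensor([0] + shapes_R_loc, dtype=torch.int32), 0
    ).tolist()[:-1]
    print(counts, displs)  # [8, 8, 7] [0, 8, 16]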
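
PATCH 3/6 lifts the tall-skinny restriction for split=0: the columns are processed in groups whose boundaries (column_idx) follow the per-process row counts, each group is factored with TS-QR, and the remaining columns are orthogonalized against the new Q block (the Allreduce on coeffs in the patch). A single-process sketch of that loop, assuming torch, a fixed block width in place of the column_idx boundaries, and at least as many rows as columns; blockwise_gs_qr is an invented name:

    import torch

    def blockwise_gs_qr(A: torch.Tensor, block: int):
        # QR via stabilized block Gram-Schmidt; assumes A.shape[-2] >= A.shape[-1].
        m, n = A.shape
        W = A.clone()
        R = torch.zeros(n, n, dtype=A.dtype)
        Q_blocks = []
        for j0 in range(0, n, block):
            j1 = min(j0 + block, n)
            # factor the current (tall-skinny) column group; TS-QR in the patch
            Qk, Rk = torch.linalg.qr(W[:, j0:j1], mode="reduced")
            R[j0:j1, j0:j1] = Rk
            if j1 < n:
                # project the remaining columns onto Qk and subtract that
                # component; in the patch, coeffs is Allreduce'd over processes
                coeffs = Qk.T @ W[:, j1:]
                R[j0:j1, j1:] = coeffs
                W[:, j1:] -= Qk @ coeffs
            Q_blocks.append(Qk)
        return torch.cat(Q_blocks, dim=1), R

    A = torch.randn(30, 24, dtype=torch.float64)
    Q, R = blockwise_gs_qr(A, block=7)
    print(torch.allclose(Q @ R, A))           # True
    print(torch.allclose(R, torch.triu(R)))   # R stays upper triangular

In the distributed version each column group spans all processes, so Q keeps split=0, and R, assembled row block by row block on the owning ranks, ends up with split=0 as well whenever A is not tall-skinny, which is the behavior change documented in the updated docstring.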
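
After this series, the square case exercised by the new qr_split_0_square benchmark runs through the code path above instead of raising a ValueError. A hedged usage sketch, assuming Heat with these patches applied, more than one MPI process, and a hypothetical script name qr_square_example.py:

    import heat as ht

    # square matrix split along the rows: rejected before this series
    # (when distributed), handled by block Gram-Schmidt plus TS-QR afterwards
    x = ht.random.randn(2000, 2000, dtype=ht.float32, split=0)
    q, r = ht.linalg.qr(x)
    print(q.split, r.split)  # expected per the new docstring: 0 and 0

    # launch with, e.g.: mpirun -n 4 python qr_square_example.py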