diff --git a/.github/Pythia_saturation.png b/.github/Pythia_saturation.png new file mode 100644 index 0000000..f68e0ed Binary files /dev/null and b/.github/Pythia_saturation.png differ diff --git a/.github/TinyLlama_logo.png b/.github/TinyLlama_logo.png new file mode 100644 index 0000000..3f2c570 Binary files /dev/null and b/.github/TinyLlama_logo.png differ diff --git a/.github/llama2-training.png b/.github/llama2-training.png new file mode 100644 index 0000000..c4e4993 Binary files /dev/null and b/.github/llama2-training.png differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..59b43e4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +__pycache__ +.idea +.DS_Store +*.egg-info +build +.venv +.vscode + +# data +data +checkpoints +out +wandb + +tests/original_falcon_40b.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..fe60df9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [2023] Lightning AI
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/PRETRAIN.md b/PRETRAIN.md
new file mode 100644
index 0000000..46fe588
--- /dev/null
+++ b/PRETRAIN.md
@@ -0,0 +1,81 @@
+## Pretrain TinyLlama
+
+### Installation
+We assume you have CUDA 11.8 installed.
+#### Install PyTorch Nightly
+```bash
+pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev'
+```
+#### Build XFormers from Source
+Note: as of 2023/09/02, xformers does not provide pre-built binaries for torch 2.1, so you have to build it from source.
+```bash
+pip uninstall ninja -y && pip install ninja -U
+pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+```
+
+
+#### Install Flash-Attention 2 and other fused operators
+```bash
+git clone https://github.com/Dao-AILab/flash-attention
+cd flash-attention
+python setup.py install
+cd csrc/rotary && pip install .
+cd ../layer_norm && pip install .
+cd ../xentropy && pip install .
+cd ../.. && rm -rf flash-attention
+```
+#### Install Remaining Dependencies
+```bash
+pip install -r requirements.txt tokenizers sentencepiece
+```
+Building xformers and flash-attention can take 5 minutes or more. Do not worry if the process seems to stall or the terminal prints many warnings.
+
+Then you are ready to go 🎉!
+
+### Data Preparation
+
+#### Download Datasets
+Download the SlimPajama and Starcoderdata datasets to a directory of your choice.
+```bash
+cd /path/to/dataset
+git lfs install
+git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B
+git clone https://huggingface.co/datasets/bigcode/starcoderdata
+```
+The SlimPajama dataset takes 893GB of disk space and Starcoderdata takes 290GB.
+
+#### Tokenize data
+Use the provided scripts to tokenize the datasets and divide them into chunks.
+```bash
+python scripts/prepare_starcoder.py --source_path /path/to/starcoderdata/ --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0
+python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split validation --percentage 1.0
+python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0
+```
+The processed data will take about 1.8TB of storage.
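+
+Before launching a long run, you may want to sanity-check the packed chunks. The sketch below is not part of the official scripts: it reads a couple of the generated `.bin` chunks back with `PackedDatasetIterator` from `lit_gpt/packed_dataset.py` and decodes the first tokens of one block with `lit_gpt.tokenizer.Tokenizer`. The `*.bin` glob, the `data/llama` tokenizer directory, and the block size of 2049 (2048-token context plus one shifted target) are assumptions about the setup above; adjust them to match your run.
+```python
+from pathlib import Path
+
+from lit_gpt.packed_dataset import PackedDatasetIterator
+from lit_gpt.tokenizer import Tokenizer
+
+data_dir = Path("data/slim_star_combined")
+filenames = sorted(str(p) for p in data_dir.glob("*.bin"))
+
+# read two chunks sequentially (no shuffling), one block of token ids at a time
+it = PackedDatasetIterator(filenames[:2], n_chunks=2, block_size=2049, seed=42, shuffle=False, wrap=False)
+block = next(it)  # 1-D LongTensor of length block_size
+
+tokenizer = Tokenizer(Path("data/llama"))
+print(block.shape)
+print(tokenizer.decode(block[:64]))  # should read as coherent SlimPajama/Starcoder text
+```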
+ +### Pretraining +If your setup comprises two nodes, each with 8 GPUs, you can initiate pretraining with the following commands: + +On node 1: +``` +lightning run model \ + --node-rank=0 \ + --main-address=172.16.101.5 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star +``` +On node 2: +``` +lightning run model \ + --node-rank=1 \ + --main-address=172.16.101.5 \ + --accelerator=cuda \ + --devices=8 \ + --num-nodes=2 \ + pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star +``` +You can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html) if you have a slurm cluster. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..ff540f4 --- /dev/null +++ b/README.md @@ -0,0 +1,159 @@ +
self._chunk_size: + part_len = self._chunk_size - self._idx + self._arr[self._idx : self._idx + part_len] = arr[:part_len] + self._write_chunk() + arr = arr[part_len:] + + arr_len = arr.shape[0] + self._arr[self._idx : self._idx + arr_len] = arr + self._idx += arr_len + + def write_reminder(self): + self._write_chunk() + + +class PackedDatasetIterator: + def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): + self._seed = seed + self._shuffle = shuffle + self._rng = np.random.default_rng(seed) if shuffle else None + self._block_idxs = None + + self._wrap = wrap + + # TODO: instead of filenames, we could have a single text stream + # (or text file) with the sequence of all files to be + # fetched/loaded. + self._filenames = filenames + self._file_idx = 0 + + self._n_chunks = n_chunks + + self._dtype = None + self._block_size = block_size + self._n_blocks = None + + self._mmaps = [] + self._buffers = [] + + self._block_idxs = [] + self._curr_idx = 0 + + self._load_n_chunks() + + def _read_header(self, path): + with open(path, "rb") as f: + magic = f.read(len(HDR_MAGIC)) + assert magic == HDR_MAGIC, "File doesn't match expected format." + version = struct.unpack("len(self._filenames[self._file_idx :]): + # if not self._wrap: + # raise StopIteration + self._file_idx = 0 + + for i in range(self._n_chunks): + filename = self._filenames[self._file_idx + i] + if self._dtype is None: + self._dtype, self._chunk_size = self._read_header(filename) + self._n_blocks = self._chunk_size // self._block_size + # TODO: check header matches with previous files + mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) + self._mmaps.append(mmap) + self._buffers.append(memoryview(mmap)) + + self._file_idx += self._n_chunks + n_all_blocks = self._n_chunks * self._n_blocks + + self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) + + self._curr_idx = 0 + + def __del__(self): + self._close_mmaps() + del self._mmaps + del self._buffers + + def __iter__(self): + return self + + def __next__(self): + if self._curr_idx >= len(self._block_idxs): + self._load_n_chunks() + # TODO: trigger fetching next next n_chunks if remote + block_idx = self._block_idxs[self._curr_idx] + chunk_id = block_idx // self._n_blocks + buffer = self._buffers[chunk_id] + elem_id = (block_idx % self._n_blocks) * self._block_size + offset = np.dtype(self._dtype).itemsize * elem_id + arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) + self._curr_idx += 1 + return torch.from_numpy(arr.astype(np.int64)) + + +class CombinedDataset(IterableDataset): + def __init__(self, datasets, seed, weights=None): + self._seed = seed + self._datasets = datasets + self._weights = weights + n_datasets = len(datasets) + if weights is None: + self._weights = [1 / n_datasets] * n_datasets + + def __iter__(self): + return CombinedDatasetIterator(self._datasets, self._seed, self._weights) + + +class CombinedDatasetIterator: + def __init__(self, datasets, seed, weights): + self._datasets = [iter(el) for el in datasets] + self._weights = weights + self._rng = random.Random(seed) + + def __next__(self): + (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) + return next(dataset) diff --git a/lit_gpt/rmsnorm.py b/lit_gpt/rmsnorm.py new file mode 100644 index 0000000..1c7362a --- /dev/null +++ b/lit_gpt/rmsnorm.py @@ -0,0 +1,842 @@ +import torch +# Copyright (c) 2022, Tri Dao. 
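+# NOTE: the `dropout_layer_norm` import below is the fused CUDA extension that ships with
+# flash-attention (built from flash-attention/csrc/layer_norm, see PRETRAIN.md). This module
+# wraps it to provide fused dropout + residual-add + LayerNorm/RMSNorm forward and backward.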
+# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16 + +import dropout_layer_norm +import torch +from torch.nn import init + + +def maybe_align(x, alignment_in_bytes=16): + """Assume that x already has last dim divisible by alignment_in_bytes""" + # TD [2023-07-04] I'm not 100% sure that clone will align the memory + # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 + return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() + + +def _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + rowscale, + colscale, + None, + None, + dropout_p, + epsilon, + 1.0, + 0, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. 
+ """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(xmat.shape) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + None, + None, + dropout_p, + 1.0, + 0, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. 
+ """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(-1, hidden_size) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma0.numel() + x0mat = x0.view((-1, hidden_size)) + x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( + x0mat, + x1mat, + residualmat, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask0 and dmask1 are None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma + + +def _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). 
+ """ + hidden_size = gamma0.numel() + xmat = x.view((-1, hidden_size)) + dz0mat = dz0.view(xmat.shape) + dz1mat = dz1.view(xmat.shape) if dz1 is not None else None + dxmat = dx.view(xmat.shape) if dx is not None else None + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + *rest, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( + dz0mat, + dz1mat, + dxmat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 + + +class DropoutAddLayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + ctx.save_for_backward( + xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + if not return_dmask: + return ( + zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) + ) + else: + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return ( + (zmat.view(x0.shape), dmask) + if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) + ) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! 
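+        # The incoming gradient is sometimes non-contiguous or not 16-byte aligned (hence the
+        # possible clone inside `maybe_align`); the fused kernel requires contiguous, aligned buffers.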
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + None, + dcolscale, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormSubsetFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + x_shape = (-1, *x0.shape[1:]) + ctx.save_for_backward( + xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.rowscale_const = rowscale_const + ctx.x0_numrows = x0.shape[:-1].numel() + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + z_shape = (-1, *x0.shape[1:]) + if not return_dmask: + return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) + else: + z = zmat.view(z_shape) + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! 
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + ctx.rowscale_const, + ctx.x0_numrows, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(-1, *x.shape[1:]) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + dcolscale, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma0 = maybe_align(gamma0.contiguous(), 16) + beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None + gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None + beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_x1 = x1 is not None + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta0 is not None + z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) + if not return_dmask: + return z if not prenorm else (*z, xmat.view(x0.shape)) + else: + dmask0 = ( + dmask0.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + dmask1 = ( + dmask1.view(x0.shape) + if dropout_p > 0.0 and x1 is not None + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask0) + ctx.mark_non_differentiable(dmask1) + return ( + (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) + ) + + @staticmethod + def backward(ctx, dz0, dz1, *args): + dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
+ dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors + dropout_p = ctx.dropout_p + has_x1 = ctx.has_x1 + has_residual = ctx.has_residual + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + ) = _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + return ( + dx0, + dx1, + dresidual, + dgamma0, + dbeta0 if ctx.has_beta else None, + dgamma1, + dbeta1 if ctx.has_beta else None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm(x, weight, bias, epsilon): + return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) + + +def dropout_add_layer_norm( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + rowscale=None, + layerscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormFn.apply( + x0, + residual, + weight, + bias, + rowscale, + layerscale, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_subset( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + layerscale=None, + x0_subset=None, + out_subset=None, + rowscale_const=1.0, + out_numrows=0, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormSubsetFn.apply( + x0, + residual, + weight, + bias, + layerscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_parallel_residual( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. 
+ """ + return DropoutAddLayerNormParallelResidualFn.apply( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +class DropoutAddLayerNorm(torch.nn.Module): + def __init__( + self, + hidden_size, + prenorm=False, + p=0.0, + eps=1e-5, + residual_in_fp32=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.eps = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x0, residual=None): + return dropout_add_layer_norm( + x0, + residual, + self.weight, + self.bias, + self.p if self.training else 0.0, + self.eps, + prenorm=self.prenorm, + residual_in_fp32=self.residual_in_fp32, + ) + +def rms_norm(x, weight, epsilon): + return DropoutAddLayerNormFn.apply( + x, None, weight, None, None, None, 0.0, epsilon, False, False, True + ) +class FusedRMSNorm(torch.nn.Module): + def __init__(self, size: int, dim: int = -1, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(size)) + self.dim = dim + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + + def forward(self, x): + return rms_norm(x, self.weight, self.eps) + + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. 
+ """ + + def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return self.weight * x_normed + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) diff --git a/lit_gpt/speed_monitor.py b/lit_gpt/speed_monitor.py new file mode 100644 index 0000000..fa81b18 --- /dev/null +++ b/lit_gpt/speed_monitor.py @@ -0,0 +1,408 @@ +import time +from collections import deque +from contextlib import nullcontext +from typing import Any, Callable, Deque, Dict, Optional + +import torch +from lightning import Callback, Fabric, LightningModule, Trainer +from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only +from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only +from torch.utils.flop_counter import FlopCounterMode +import math +from lit_gpt import GPT, Config +from lit_gpt.utils import num_parameters + +GPU_AVAILABLE_FLOPS = { + # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet + # nvidia publishes spec sheet with a 2x sparsity factor + "h100-sxm": { + "64-true": 67e12, + "32-true": 67e12, + "16-true": 1.979e15 / 2, + "16-mixed": 1.979e15 / 2, + "bf16-true": 1.979e15 / 2, + "bf16-mixed": 1.979e15 / 2, + "8-true": 3.958e15 / 2, + "8-mixed": 3.958e15 / 2, + }, + "h100-pcie": { + "64-true": 51e12, + "32-true": 51e12, + "16-true": 1.513e15 / 2, + "16-mixed": 1.513e15 / 2, + "bf16-true": 1.513e15 / 2, + "bf16-mixed": 1.513e15 / 2, + "8-true": 3.026e15 / 2, + "8-mixed": 3.026e15 / 2, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + # sxm and pcie have same flop counts + "a100": { + "64-true": 19.5e12, + "32-true": 19.5e12, + "16-true": 312e12, + "16-mixed": 312e12, + "bf16-true": 312e12, + "bf16-mixed": 312e12, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf + "a10g": {"32-true": 31.2e12, "16-true": 125e12, "16-mixed": 125e12, "bf16-true": 125e12, "bf16-mixed": 125e12}, + # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf + "v100-sxm": {"64-true": 7.8e12, "32-true": 15.7e12, "16-true": 125e12, "16-mixed": 125e12}, + "v100-pcie": {"64-true": 7e12, "32-true": 14e12, "16-true": 112e12, "16-mixed": 112e12}, + "v100s-pcie": {"64-true": 8.2e12, "32-true": 16.4e12, "16-true": 130e12, "16-mixed": 130e12}, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf + # sxm and pcie have same flop counts + "t4": {"32-true": 8.1e12, "16-true": 65e12, "16-mixed": 65e12, "8-true": 130e12, "int4": 260e12}, + # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf + "quadro rtx 5000": {"32-true": 11.2e12, "16-true": 89.2e12, "16-mixed": 89.2e12}, +} + +TPU_AVAILABLE_FLOPS = { + # flop count for each TPU generation is the same for all precisions + # since bfloat16 precision is always used for performing matrix operations + # for more info: 
https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16 + # source: https://arxiv.org/pdf/1907.10701.pdf + "v2": 45e12, + # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3 + "v3": 123e12, + # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4 + "v4": 275e12, +} + + +def get_flops_available(device: torch.device, precision: str) -> Optional[float]: + if device.type == "cuda": + device_name = torch.cuda.get_device_name(device).lower() + if "h100" in device_name and "hbm3" in device_name: + device_name = "h100-sxm" + elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name): + device_name = "h100-pcie" + elif "a100" in device_name: + device_name = "a100" + elif "a10g" in device_name: + device_name = "a10g" + elif "v100-sxm" in device_name: + device_name = "v100-sxm" + elif "v100-pcie" in device_name: + device_name = "v100-pcie" + elif "t4" in device_name: + device_name = "t4" + elif "quadro rtx 5000" in device_name: + device_name = "quadro rtx 5000" + else: + device_name = None + + if device_name is not None: + try: + return int(GPU_AVAILABLE_FLOPS[device_name][precision]) + except KeyError: + raise KeyError( + f"flop count not found for {device_name} with precision: {precision}; " + "MFU cannot be calculated and reported." + ) + elif device.type == "xla": + from torch_xla.experimental import tpu + + device_name = tpu.get_tpu_env()["TYPE"].lower() + try: + return int(TPU_AVAILABLE_FLOPS[device_name]) + except KeyError: + raise KeyError( + f"flop count not found for {device_name} with precision: {precision}; " + "MFU cannot be calculated and reported." + ) + + return None + + +# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py + + +class SpeedMonitorBase: + """Logs the training throughput and utilization. + + +-------------------------------------+-----------------------------------------------------------+ + | Key | Logged data | + +=====================================+===========================================================+ + | | Rolling average (over `window_size` most recent | + | `throughput/batches_per_sec` | batches) of the number of batches processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/samples_per_sec` | batches) of the number of samples processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. 
| + | | This may include padding depending on dataset | + +-------------------------------------+-----------------------------------------------------------+ + | | Estimates flops by `flops_per_batch * batches_per_sec` | + | `throughput/flops_per_sec` | | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size | + +-------------------------------------+-----------------------------------------------------------+ + | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/tokens_per_sec` divided by world size. This | + | `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/flops_per_sec` divided by world size. Only | + | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/device/flops_per_sec` divided by world size. | + | `throughput/device/mfu` | | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | `time/train` | Total elapsed training time | + +-------------------------------------+-----------------------------------------------------------+ + | `time/val` | Total elapsed validation time | + +-------------------------------------+-----------------------------------------------------------+ + | `time/total` | Total elapsed time (time/train + time/val) | + +-------------------------------------+-----------------------------------------------------------+ + + Notes: + - The implementation assumes that devices are homogeneous as it normalizes by the world size. + - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or + batches/sec to measure throughput under this circumstance. + - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``. + There is no widespread, realistic, and reliable implementation to compute them. + We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which + will almost always be an overestimate when compared to the true value. + + Args: + window_size (int, optional): Number of batches to use for a rolling average of throughput. + Defaults to 100. + time_unit (str, optional): Time unit to use for `time` logging. Can be one of + 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'. 
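+        log_iter_interval (int, optional): Number of calls to ``on_train_batch_end`` between two
+            logging calls; the reported training loss is averaged over this window. Defaults to 1.
+
+    Example (illustrative sketch, assuming a Fabric training loop that tracks elapsed time, total
+    token lengths, and a measured per-batch FLOP count; the variable names are placeholders)::
+
+        monitor = SpeedMonitorFabric(fabric, window_size=100, log_iter_interval=10)
+        for iter_num, batch in enumerate(train_data):
+            ...  # forward / backward / optimizer step
+            monitor.on_train_batch_end(
+                (iter_num + 1) * micro_batch_size,
+                time.perf_counter() - train_t0,
+                fabric.world_size,
+                flops_per_batch=measured_flops,
+                lengths=total_lengths,
+                train_loss=loss.item(),
+            )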
+ """ + + def __init__( + self, + flops_available: float, + log_dict: Callable[[Dict, int], None], + window_size: int = 100, + time_unit: str = "hours", + log_iter_interval: int = 1, + ): + self.flops_available = flops_available + self.log_dict = log_dict + self.log_iter_interval = log_iter_interval + # Track the batch num samples and wct to compute throughput over a window of batches + self.history_samples: Deque[int] = deque(maxlen=window_size + 1) + self.history_training_loss: Deque[int] = deque(maxlen=log_iter_interval) + self.history_wct: Deque[float] = deque(maxlen=window_size + 1) + self.history_lengths: Deque[int] = deque(maxlen=window_size + 1) + self.history_flops: Deque[int] = deque(maxlen=window_size + 1) + + self.divider = 1 + if time_unit == "seconds": + self.divider = 1 + elif time_unit == "minutes": + self.divider = 60 + elif time_unit == "hours": + self.divider = 60 * 60 + elif time_unit == "days": + self.divider = 60 * 60 * 24 + else: + raise ValueError( + f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".' + ) + + # Keep track of time spent evaluating + self.total_eval_wct = 0.0 + self.iter = -1 + + def on_train_batch_end( + self, + samples: int, # total samples seen (per device) + train_elapsed: float, # total training time (seconds) + world_size: int, + flops_per_batch: Optional[int] = None, # (per device) + lengths: Optional[int] = None, # total length of the samples seen (per device) + train_loss: Optional[float] = None, + ): + self.iter += 1 + metrics = {} + + self.history_samples.append(samples) + self.history_training_loss.append(train_loss) + if lengths is not None: + self.history_lengths.append(lengths) + # if lengths are passed, there should be as many values as samples + assert len(self.history_samples) == len(self.history_lengths) + self.history_wct.append(train_elapsed) + if len(self.history_wct) == self.history_wct.maxlen: + elapsed_batches = len(self.history_samples) - 1 + elapsed_samples = self.history_samples[-1] - self.history_samples[0] + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + samples_per_sec = elapsed_samples * world_size / elapsed_wct + dev_samples_per_sec = elapsed_samples / elapsed_wct + metrics.update( + { + "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct, + "throughput/samples_per_sec": samples_per_sec, + "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct, + "throughput/device/samples_per_sec": dev_samples_per_sec, + } + ) + if lengths is not None: + elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0]) + avg_length = elapsed_lengths / elapsed_batches + metrics.update( + { + "throughput/tokens_per_sec": samples_per_sec * avg_length, + "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length, + "total_tokens": avg_length * world_size * samples, + } + ) + if train_loss is not None: + avg_loss = sum(self.history_training_loss) / len(self.history_training_loss) + metrics.update( + { + "metric/train_loss": avg_loss, + "metric/train_ppl": math.exp(avg_loss) + } + ) + + if flops_per_batch is not None: + # sum of flops per batch across ranks + self.history_flops.append(flops_per_batch * world_size) + if len(self.history_flops) == self.history_flops.maxlen: + elapsed_flops = sum(self.history_flops) - self.history_flops[0] + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + flops_per_sec = elapsed_flops / elapsed_wct + device_flops_per_sec = flops_per_sec / world_size + metrics.update( + 
{"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec} + ) + if self.flops_available: + metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available + + metrics.update( + { + "time/train": train_elapsed / self.divider, + "time/val": self.total_eval_wct / self.divider, + "time/total": (train_elapsed + self.total_eval_wct) / self.divider, + "samples": samples, + } + ) + if self.iter % self.log_iter_interval == 0: + self.log_dict(metrics, self.iter//self.log_iter_interval) + + def eval_end(self, eval_elapsed: float): + self.total_eval_wct += eval_elapsed # seconds + + +class SpeedMonitorFabric(SpeedMonitorBase): + def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None: + # TODO: this will not work properly if a precision plugin is passed to Fabric + flops_available = get_flops_available(fabric.device, fabric._connector._precision_input) + super().__init__(flops_available, fabric.log_dict, *args, **kwargs) + + @fabric_rank_zero_only + def on_train_batch_end(self, *args: Any, **kwargs: Any): + super().on_train_batch_end(*args, **kwargs) + + +class SpeedMonitorCallback(Callback): + def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None: + super().__init__() + self.speed_monitor: Optional[SpeedMonitorBase] = None + self.speed_monitor_kwargs = kwargs + self.length_fn = length_fn + self.batch_size = batch_size + self.eval_t0: int = 0 + self.train_t0: int = 0 + self.total_lengths: int = 0 + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + if self.speed_monitor is not None: + return # already setup + # TODO: this will not work properly if a precision plugin is passed to Trainer + flops_available = get_flops_available( + trainer.strategy.root_device, trainer._accelerator_connector._precision_flag + ) + self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs) + + @trainer_rank_zero_only + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + if trainer.fit_loop._should_accumulate(): + return + + self.train_t0 = time.perf_counter() + + @trainer_rank_zero_only + def on_train_batch_end( + self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int + ) -> None: + self.total_lengths += self.length_fn(batch) + if trainer.fit_loop._should_accumulate(): + return + train_elapsed = time.perf_counter() - self.train_t0 + assert self.speed_monitor is not None + iter_num = trainer.fit_loop.total_batch_idx + assert (measured_flops := pl_module.measured_flops) is not None + self.speed_monitor.on_train_batch_end( + (iter_num + 1) * self.batch_size, + train_elapsed, + # this assumes that device FLOPs are the same and that all devices have the same batch size + trainer.world_size, + flops_per_batch=measured_flops, + lengths=self.total_lengths, + ) + + @trainer_rank_zero_only + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self.eval_t0 = time.perf_counter() + + @trainer_rank_zero_only + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + eval_elapsed = time.perf_counter() - self.eval_t0 + assert self.speed_monitor is not None + self.speed_monitor.eval_end(eval_elapsed) + + +def flops_per_param(config: Config, n_params: int) -> int: + flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation + # this assumes that all samples have a fixed length 
equal to the block size + # which is most likely false during finetuning + flops_per_seq = flops_per_token * config.block_size + attn_flops_per_seq = config.n_layer * 2 * 2 * (config.n_embd * (config.block_size**2)) + return flops_per_seq + attn_flops_per_seq + + +def estimate_flops(model: GPT) -> int: + """Measures estimated FLOPs for MFU. + + Refs: + * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 + * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 + """ + # using all parameters for this is a naive over estimation because not all model parameters actually contribute to + # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage + # (~10%) compared to the measured FLOPs, making those lower but more realistic. + # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper. + n_trainable_params = num_parameters(model, requires_grad=True) + trainable_flops = flops_per_param(model.config, n_trainable_params) + # forward + backward + gradients (assumes no gradient accumulation) + ops_per_step = 3 if model.training else 1 + n_frozen_params = num_parameters(model, requires_grad=False) + frozen_flops = flops_per_param(model.config, n_frozen_params) + # forward + backward + frozen_ops_per_step = 2 if model.training else 1 + return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops + + +def measure_flops(model: GPT, x: torch.Tensor) -> int: + """Measures real FLOPs for HFU""" + flop_counter = FlopCounterMode(model, display=False) + ctx = nullcontext() if model.training else torch.no_grad() + with ctx, flop_counter: + y = model(x) + if model.training: + y.sum().backward() + return flop_counter.get_total_flops() diff --git a/lit_gpt/tokenizer.py b/lit_gpt/tokenizer.py new file mode 100644 index 0000000..a076c13 --- /dev/null +++ b/lit_gpt/tokenizer.py @@ -0,0 +1,77 @@ +import json +from pathlib import Path +from typing import Optional + +import torch + + +class Tokenizer: + def __init__(self, checkpoint_dir: Path) -> None: + # some checkpoints have both files, `.model` takes precedence + if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): + from sentencepiece import SentencePieceProcessor + + self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) + self.backend = "sentencepiece" + self.bos_id = self.processor.bos_id() + self.eos_id = self.processor.eos_id() + elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): + from tokenizers import Tokenizer as HFTokenizer + + self.processor = HFTokenizer.from_file(str(vocabulary_path)) + self.backend = "huggingface" + with open(checkpoint_dir / "tokenizer_config.json") as fp: + config = json.load(fp) + bos_token = config.get("bos_token") + self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None + self.eos_id = self.token_to_id(config["eos_token"]) + else: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + if self.backend == "huggingface": + return self.processor.get_vocab_size(with_added_tokens=False) + if self.backend == "sentencepiece": + return self.processor.vocab_size() + raise RuntimeError + + def token_to_id(self, token: str) -> int: + if self.backend == "huggingface": + id_ = self.processor.token_to_id(token) + elif self.backend == "sentencepiece": + id_ = self.processor.piece_to_id(token) + else: + raise RuntimeError + if id_ is None: + raise ValueError(f"token {token!r} not found in the collection.") + return id_ + + def encode( + self, + string: 
str, + device: Optional[torch.device] = None, + bos: bool = False, + eos: bool = False, + max_length: int = -1, + ) -> torch.Tensor: + if self.backend == "huggingface": + tokens = self.processor.encode(string).ids + elif self.backend == "sentencepiece": + tokens = self.processor.encode(string) + else: + raise RuntimeError + if bos: + bos_id = self.bos_id + if bos_id is None: + raise NotImplementedError("This tokenizer does not defined a bos token") + tokens = [bos_id] + tokens + if eos: + tokens = tokens + [self.eos_id] + if max_length > 0: + tokens = tokens[:max_length] + return torch.tensor(tokens, dtype=torch.int, device=device) + + def decode(self, tensor: torch.Tensor) -> str: + tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() + return self.processor.decode(tokens) diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py new file mode 100644 index 0000000..d1d7bc6 --- /dev/null +++ b/lit_gpt/utils.py @@ -0,0 +1,505 @@ +"""Utility functions for training and inference.""" + +import pickle +import sys +import warnings +from contextlib import contextmanager +from functools import partial +from io import BytesIO +from pathlib import Path +from types import MethodType +from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union + +import torch +import torch.nn as nn +import torch.utils._device +from lightning.fabric.loggers import CSVLogger +from torch.serialization import normalize_storage_type + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n + k - (n % k) + + +def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: + return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad) + + +@contextmanager +def quantization(mode: Optional[str] = None): + if mode is None: + yield + return + + if mode == "bnb.int8": + from quantize.bnb import InferenceLinear8bitLt + + quantized_linear_cls = InferenceLinear8bitLt + elif mode == "bnb.fp4": + from quantize.bnb import Linear4bit + + # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "bnb.fp4-dq": + from quantize.bnb import Linear4bit + + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "bnb.nf4": + from quantize.bnb import Linear4bit + + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "bnb.nf4-dq": + from quantize.bnb import Linear4bit + + class QuantizedLinear(Linear4bit): + def __init__(self, *args, **kwargs): + super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) + + quantized_linear_cls = QuantizedLinear + elif mode == "gptq.int4": + from quantize.gptq import ColBlockQuantizedLinear + + class QuantizedLinear(ColBlockQuantizedLinear): + def __init__(self, *args, **kwargs): + super().__init__(*args, bits=4, tile_cols=-1, **kwargs) + + quantized_linear_cls = QuantizedLinear + else: + raise ValueError(f"Unknown quantization mode: {mode}") + + torch_linear_cls = torch.nn.Linear + torch.nn.Linear = 
quantized_linear_cls + yield + torch.nn.Linear = torch_linear_cls + + +# this is taken from torchhacks https://github.com/lernapparat/torchhacks + + +class NotYetLoadedTensor: + def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): + self.metatensor = metatensor + self.archiveinfo = archiveinfo + self.storageinfo = storageinfo + self.rebuild_args = rebuild_args + + @classmethod + def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): + ret = func(*args) + if isinstance(ret, NotYetLoadedTensor): + old_lt = ret._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) + + ret._load_tensor = _load_tensor + return ret + return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) + + @classmethod + def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): + if isinstance(data, NotYetLoadedTensor): + old_lt = data._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) + + data._load_tensor = _load_tensor + return data + return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) + + @classmethod + def rebuild_tensor_v2( + cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None + ): + rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) + metatensor = torch._utils._rebuild_tensor_v2( + storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata + ) + storageinfo = storage.archiveinfo + return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) + + def _load_tensor(self): + name, storage_cls, fn, device, size = self.storageinfo + dtype = self.metatensor.dtype + + uts = ( + self.archiveinfo.zipfile_context.zf.get_storage_from_record( + f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage + ) + ._typed_storage() + ._untyped_storage + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) + return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] + return func(*loaded_args, **kwargs) + # gc.collect would be costly here, maybe do it optionally + + def __getattr__(self, name): + # properties + ## TODO: device, is_...?? + ## TODO: mH, mT, H, T, data, imag, real + ## name ??? 
+ if name in { + "dtype", + "grad", + "grad_fn", + "layout", + "names", + "ndim", + "output_nr", + "requires_grad", + "retains_grad", + "shape", + "volatile", + }: + return getattr(self.metatensor, name) + if name in {"size"}: + return getattr(self.metatensor, name) + # materializing with contiguous is needed for quantization + if name in {"contiguous"}: + return getattr(self._load_tensor(), name) + + raise AttributeError(f"{type(self)} does not have {name}") + + def __repr__(self): + return f"NotYetLoadedTensor({repr(self.metatensor)})" + + +class LazyLoadingUnpickler(pickle.Unpickler): + def __init__(self, file, zipfile_context): + super().__init__(file) + self.zipfile_context = zipfile_context + + def find_class(self, module, name): + res = super().find_class(module, name) + if module == "torch._utils" and name == "_rebuild_tensor_v2": + return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) + if module == "torch._tensor" and name == "_rebuild_from_type_v2": + return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) + if module == "torch._utils" and name == "_rebuild_parameter": + return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) + return res + + def persistent_load(self, pid): + name, cls, fn, device, size = pid + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") + s.archiveinfo = pid + return s + + +class lazy_load: + def __init__(self, fn): + self.zf = torch._C.PyTorchFileReader(str(fn)) + with BytesIO(self.zf.get_record("data.pkl")) as pkl: + mup = LazyLoadingUnpickler(pkl, self) + self.sd = mup.load() + + def __enter__(self): + return self.sd + + def __exit__(self, exc_type, exc_val, exc_tb): + del self.zf # I don't think there is a way to force closing... + self.zf = None + + +def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: + files = { + "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), + "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), + "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( + checkpoint_dir / "tokenizer.model" + ).is_file(), + "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), + } + if checkpoint_dir.is_dir(): + if all(files.values()): + # we're good + return + problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" + else: + problem = " is not a checkpoint directory" + + # list locally available checkpoints + available = list(Path("checkpoints").glob("*/*")) + if available: + options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available]) + extra = f"\nYou have downloaded locally:{options}\n" + else: + extra = "" + + error_message = ( + f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." 
+ "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n" + f"{extra}\nSee all download options by running:\n python scripts/download.py" + ) + print(error_message, file=sys.stderr) + raise SystemExit(1) + + +class SavingProxyForStorage: + def __init__(self, obj, saver, protocol_version=5): + self.protocol_version = protocol_version + self.saver = saver + if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): + raise TypeError(f"expected storage, not {type(obj)}") + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + if isinstance(obj, torch.storage.TypedStorage): + # PT upstream wants to deprecate this eventually... + storage = obj._untyped_storage + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + else: + storage = obj + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + storage_key = saver._write_storage_and_return_key(storage) + location = torch.serialization.location_tag(storage) + + self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) + + def __reduce_ex__(self, protocol_version): + assert False, "this should be handled with out of band" + + +class SavingProxyForTensor: + def __init__(self, tensor, saver, protocol_version=5): + self.protocol_version = protocol_version + self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version) + assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" + storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) + self.reduce_args = (storage_proxy, *other_reduce_args) + + def __reduce_ex__(self, protocol_version): + if protocol_version != self.protocol_version: + raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") + return self.reduce_ret_fn, self.reduce_args + + +class IncrementalPyTorchPickler(pickle.Pickler): + def __init__(self, saver, *args, **kwargs): + super().__init__(*args, **kwargs) + self.storage_dtypes = {} + self.saver = saver + self.id_map = {} + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + def persistent_id(self, obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, SavingProxyForStorage): + return obj.storage_info + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. 
If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in self.storage_dtypes: + if storage_dtype != self.storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that view the same data as different types" + ) + else: + self.storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = self.id_map.get(storage._cdata) + if storage_key is None: + storage_key = self.saver._write_storage_and_return_key(storage) + self.id_map[storage._cdata] = storage_key + location = torch.serialization.location_tag(storage) + + return ("storage", storage_type, storage_key, location, storage_numel) + + return None + + +class incremental_save: + def __init__(self, name): + self.name = name + self.zipfile = torch._C.PyTorchFileWriter(str(name)) + self.has_saved = False + self.next_key = 0 + + def __enter__(self): + return self + + def store_early(self, tensor): + if isinstance(tensor, torch.Tensor): + return SavingProxyForTensor(tensor, self) + raise TypeError(f"can only store tensors early, not {type(tensor)}") + + def save(self, obj): + if self.has_saved: + raise RuntimeError("have already saved") + # Write the pickle data for `obj` + data_buf = BytesIO() + pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) + pickler.dump(obj) + data_value = data_buf.getvalue() + self.zipfile.write_record("data.pkl", data_value, len(data_value)) + self.has_saved = True + + def _write_storage_and_return_key(self, storage): + if self.has_saved: + raise RuntimeError("have already saved") + key = self.next_key + self.next_key += 1 + name = f"data/{key}" + if storage.device.type != "cpu": + storage = storage.cpu() + num_bytes = storage.nbytes() + self.zipfile.write_record(name, storage.data_ptr(), num_bytes) + return key + + def __exit__(self, type, value, traceback): + self.zipfile.write_end_of_file() + + +T = TypeVar("T") + + +def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T: + logger = cls(*args, **kwargs) + + def merge_by(dicts, key): + from collections import defaultdict + + out = defaultdict(dict) + for d in dicts: + if key in d: + out[d[key]].update(d) + return [v for _, v in sorted(out.items())] + + def save(self) -> None: + """Overridden to merge CSV by the step number.""" + import csv + + if not self.metrics: + return + metrics = merge_by(self.metrics, "step") + keys = sorted({k for m in metrics for k in m}) + with self._fs.open(self.metrics_file_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=keys) + writer.writeheader() + writer.writerows(metrics) + + logger.experiment.save = MethodType(save, logger.experiment) + + return logger + + +def chunked_cross_entropy( + logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128 +) -> torch.Tensor: + # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate + # the memory usage in fine-tuning settings with low number of parameters. 
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing + # the memory spike's magnitude + + # lm_head was chunked (we are fine-tuning) + if isinstance(logits, list): + # don't want to chunk cross entropy + if chunk_size == 0: + logits = torch.cat(logits, dim=1) + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) + + # chunk cross entropy + logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] + target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] + loss_chunks = [ + torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + return torch.cat(loss_chunks).mean() + + # no chunking at all + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + if chunk_size == 0: + return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) + + # lm_head wasn't chunked, chunk cross entropy + logit_chunks = logits.split(chunk_size) + target_chunks = targets.split(chunk_size) + loss_chunks = [ + torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + return torch.cat(loss_chunks).mean() + + +def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: + for checkpoint_name, attribute_name in mapping.items(): + full_checkpoint_name = prefix + checkpoint_name + if full_checkpoint_name in state_dict: + full_attribute_name = prefix + attribute_name + state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) + return state_dict + + +def get_default_supported_precision(training: bool, tpu: bool = False) -> str: + """Return default precision that is supported by the hardware. 
+ + Args: + training: `-mixed` or `-true` version of the precision to use + tpu: whether TPU device is used + + Returns: + default precision that is suitable for the task and is supported by the hardware + """ + if tpu: + return "32-true" + if not torch.cuda.is_available() or torch.cuda.is_bf16_supported(): + return "bf16-mixed" if training else "bf16-true" + return "16-mixed" if training else "16-true" diff --git a/pretrain/tinyllama.py b/pretrain/tinyllama.py new file mode 100644 index 0000000..f01ef1c --- /dev/null +++ b/pretrain/tinyllama.py @@ -0,0 +1,395 @@ +import glob +import math +import sys +import time +from pathlib import Path +from typing import Optional, Tuple, Union +import math +import lightning as L +import torch +from lightning.fabric.strategies import FSDPStrategy, XLAStrategy +from torch.utils.data import DataLoader +from functools import partial +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) +# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually +from lit_gpt.model import GPT, Block, Config, CausalSelfAttention +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor +from lit_gpt.speed_monitor import estimate_flops, measure_flops +from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load +from pytorch_lightning.loggers import WandbLogger +from lit_gpt import FusedCrossEntropyLoss +import random + +model_name = "tiny_LLaMA_1b" +name = "tinyllama_1b" +out_dir = Path("out") / name + +# Hyperparameters +num_of_devices = 8 +global_batch_size = 512 +learning_rate = 4e-4 +micro_batch_size = 8 +max_step = 715256 * 2 +warmup_steps = 2000 +log_step_interval = 10 +eval_iters = 100 +save_step_interval = 5000 +eval_step_interval = 5000 + + +weight_decay = 1e-1 +beta1 = 0.9 +beta2 = 0.95 +grad_clip = 1.0 +decay_lr = True +min_lr = 4e-4 + +batch_size = global_batch_size // num_of_devices +gradient_accumulation_steps = batch_size // micro_batch_size +assert gradient_accumulation_steps > 0 +warmup_iters = warmup_steps * gradient_accumulation_steps + + + + +max_iters = max_step * gradient_accumulation_steps +lr_decay_iters = max_iters +log_iter_interval = log_step_interval * gradient_accumulation_steps + + +# Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. +train_data_config = [ + ("train_slim", 0.693584), + ("train_star", 0.306416), +] + +val_data_config = [ + ("validation", 1.0), +] + +hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} +logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) +wandb_logger = WandbLogger() + + +def setup( + devices: int = 8, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + precision: Optional[str] = None, + tpu: bool = False, + resume: Union[bool, Path] = False, +) -> None: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy=None, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) + fabric.print(hparams) + #fabric.launch(main, train_data_dir, val_data_dir, resume) + main(fabric, train_data_dir, val_data_dir, resume) + + +def main(fabric, train_data_dir, val_data_dir, resume): + monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) + + if fabric.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + + train_dataloader, val_dataloader = create_dataloaders( + batch_size=micro_batch_size, + block_size=config.block_size, + fabric=fabric, + train_data_dir=train_data_dir, + val_data_dir=val_data_dir, + seed=(3407 + fabric.global_rank), + ) + if val_dataloader is None: + train_dataloader = fabric.setup_dataloaders(train_dataloader) + else: + train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) + + fabric.seed_everything(3407) # same seed for every process to init model (FSDP) + + fabric.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + model.apply(partial(model._init_weights ,n_layer=config.n_layer)) + + + fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + fabric.print(f"Total parameters {num_parameters(model):,}") + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False + ) + # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) + optimizer = fabric.setup_optimizers(optimizer) + + state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} + + if resume is True: + resume = sorted(out_dir.glob("*.pth"))[-1] + if resume : + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + + train_time = time.perf_counter() + train(fabric, state, train_dataloader, val_dataloader, monitor, resume) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + + +def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): + model = state["model"] + optimizer = state["optimizer"] + + if val_dataloader is not None: + validate(fabric, model, val_dataloader) # sanity check + + with torch.device("meta"): + meta_model = GPT(model.config) + # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. + # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, + # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead + estimated_flops = estimate_flops(meta_model) * micro_batch_size + fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") + x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) + # measured_flos run in meta. 
Will trigger fusedRMSNorm error + #measured_flops = measure_flops(meta_model, x) + #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") + del meta_model, x + + total_lengths = 0 + total_t0 = time.perf_counter() + + if fabric.device.type == "xla": + import torch_xla.core.xla_model as xm + + xm.mark_step() + + + initial_iter = state["iter_num"] + curr_iter = 0 + + loss_func = FusedCrossEntropyLoss() + for train_data in train_dataloader: + # resume loader state. This is not elegant but it works. Should rewrite it in the future. + if resume: + if curr_iter < initial_iter: + curr_iter += 1 + continue + else: + resume = False + curr_iter = -1 + fabric.barrier() + fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) + if state["iter_num"] >= max_iters: + break + + # determine and set the learning rate for this iteration + lr = get_lr(state["iter_num"]) if decay_lr else learning_rate + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + iter_t0 = time.perf_counter() + + input_ids = train_data[:, 0 : model.config.block_size].contiguous() + targets = train_data[:, 1 : model.config.block_size + 1].contiguous() + is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + loss = loss_func(logits, targets) + # loss = chunked_cross_entropy(logits, targets, chunk_size=0) + fabric.backward(loss / gradient_accumulation_steps) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=grad_clip) + optimizer.step() + optimizer.zero_grad() + state["step_count"] += 1 + elif fabric.device.type == "xla": + xm.mark_step() + state["iter_num"] += 1 + # input_id: B L + total_lengths += input_ids.size(1) + t1 = time.perf_counter() + fabric.print( + f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" + f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " + # print days as well + f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" + ) + + monitor.on_train_batch_end( + state["iter_num"] * micro_batch_size, + t1 - total_t0, + # this assumes that device FLOPs are the same and that all devices have the same batch size + fabric.world_size, + flops_per_batch=estimated_flops, + lengths=total_lengths, + train_loss = loss.item() + ) + + + + + if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: + + t0 = time.perf_counter() + val_loss = validate(fabric, model, val_dataloader) + t1 = time.perf_counter() - t0 + monitor.eval_end(t1) + fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") + fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"]) + fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"]) + fabric.barrier() + if not is_accumulating and state["step_count"] % save_step_interval == 0: + checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +@torch.no_grad() +def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + + losses = torch.zeros(eval_iters, device=fabric.device) + for k, val_data in enumerate(val_dataloader): + if k >= eval_iters: + break + input_ids = val_data[:, 0 : model.config.block_size].contiguous() + targets = val_data[:, 1 : model.config.block_size + 1].contiguous() + logits = model(input_ids) + loss = chunked_cross_entropy(logits, targets, chunk_size=0) + + # loss_func = FusedCrossEntropyLoss() + # loss = loss_func(logits, targets) + losses[k] = loss.item() + + out = losses.mean() + + model.train() + return out + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" +) -> DataLoader: + datasets = [] + data_config = train_data_config if split == "train" else val_data_config + for prefix, _ in data_config: + filenames = glob.glob(str(data_dir / f"{prefix}*")) + random.seed(seed) + random.shuffle(filenames) + + dataset = PackedDataset( + filenames, + # n_chunks control the buffer size. + # Note that the buffer size also impacts the random shuffle + # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) + n_chunks=8, + block_size=block_size, + shuffle=shuffle, + seed=seed, + num_processes=fabric.world_size, + process_rank=fabric.global_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
+ ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + fabric, + train_data_dir: Path = Path("data/redpajama_sample"), + val_data_dir: Optional[Path] = None, + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + split="train" + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + fabric=fabric, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + split="validation" + ) + if val_data_dir + else None + ) + return train_dataloader, val_dataloader + + +# learning rate decay scheduler (cosine with warmup) +def get_lr(it): + # 1) linear warmup for warmup_iters steps + if it < warmup_iters: + return learning_rate * it / warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > lr_decay_iters: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return min_lr + coeff * (learning_rate - min_lr) + + +if __name__ == "__main__": + # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" + # torch.backends.cuda.enable_flash_sdp(False) + torch.set_float32_matmul_precision("high") + + from jsonargparse import CLI + + CLI(setup) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..cf1edab --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +# torch>=2.1.0dev +lightning @ git+https://github.com/Lightning-AI/lightning@master +jsonargparse[signatures] # CLI +pandas +pyarrow +tokenizers +sentencepiece +wandb +zstd +# other optional dependencies are +# sentencepiece # pythia, falcon, redpajama +# tokenizers # llama-based models +# bitsandbytes>=0.41.1 # quantize/bnb.py +# scipy # TODO: remove when https://github.com/TimDettmers/bitsandbytes/pull/525 is released +# datasets # quantize/gptq.py +# zstandard # scripts/prepare_redpajama.py +# git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py new file mode 100644 index 0000000..09b6ed8 --- /dev/null +++ b/scripts/convert_lit_checkpoint.py @@ -0,0 +1,264 @@ +import contextlib +import gc +import sys +from functools import partial +from pathlib import Path +from typing import Dict, Literal, Optional, Tuple, Union + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Config +from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load +# from scripts.convert_hf_checkpoint import layer_template, load_param + + +def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: + split = layer_name.split(".") + number = int(split[idx]) + split[idx] = "{}" + from_name = ".".join(split) + return from_name, number + 
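+# Usage sketch for layer_template (illustrative weight name, not taken from a real checkpoint):
+#   layer_template("transformer.h.11.attn.attn.weight", 2) -> ("transformer.h.{}.attn.attn.weight", 11)
+# The copy_weights_* functions below use the returned template as the key into their weight_map
+# dictionaries and then format the mapped Hugging Face name with the extracted layer index.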
+ +def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: + if hasattr(param, "_load_tensor"): + # support tensors loaded via `lazy_load()` + print(f"Loading {name!r} into RAM") + param = param._load_tensor() + if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: + print(f"Converting {name!r} from {param.dtype} to {dtype}") + param = param.to(dtype) + return param +def copy_weights_falcon( + size: Literal["7b", "40b"], + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +): + weight_map = { + "transformer.wte.weight": "transformer.word_embeddings.weight", + "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", + "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", + "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", + "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", + "transformer.ln_f.bias": "transformer.ln_f.bias", + "transformer.ln_f.weight": "transformer.ln_f.weight", + "lm_head.weight": "lm_head.weight", + } + # the original model definition is different for each size + if size == "7b": + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", + } + ) + elif size == "40b": + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", + "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", + "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", + } + ) + else: + raise NotImplementedError + + for name, param in lit_weights.items(): + if "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name].format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_gpt_neox( + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "gpt_neox.embed_in.weight", + "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", + "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", + "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", + "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", + "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", + "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", + "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", + "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", + "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", + "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", + "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", + "transformer.ln_f.bias": 
"gpt_neox.final_layer_norm.bias", + "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", + "lm_head.weight": "embed_out.weight", + } + + for name, param in lit_weights.items(): + if "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name].format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_llama( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +): + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight", + "transformer.ln_f.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", + } + for name, param in lit_weights.items(): + if name.endswith(".attn.attn.weight"): + from_name, number = layer_template(name, 2) + q = "model.layers.{}.self_attn.q_proj.weight".format(number) + k = "model.layers.{}.self_attn.k_proj.weight".format(number) + v = "model.layers.{}.self_attn.v_proj.weight".format(number) + qkv = load_param(param, name, None) + qp, kp, vp = tensor_split(qkv, config) + for to_name, param in zip((q, k, v), (qp, kp, vp)): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + elif "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name] + + if to_name is None: + continue + to_name = to_name.format(number) + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def tensor_split( + param: Union[torch.Tensor, NotYetLoadedTensor], config: Config +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def kstart(start, blen, klen) -> int: + """returns start index of keys in batch""" + return start + (blen - (klen * 2)) + + def vstart(start, blen, klen) -> int: + """returns start index of values in batch""" + return start + blen - klen + + def vend(start, blen) -> int: + """returns last index of values in batch""" + return start + blen + + # num observations + nobs = param.shape[0] + # batch length + blen = nobs // config.n_query_groups + # key length in batch + klen = config.head_size + # value length in batch + vlen = config.head_size + # the starting index of each new batch + starts = range(0, nobs, blen) + # the indices to splice on + splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] + + qc = () + kc = () + vc = () + + for splice in splices: + qs, ks, vs, ve = splice + qc += (param[qs:ks, :],) + kc += (param[ks:vs, :],) + vc += (param[vs:ve, :],) + + q = torch.cat(qc) + k = torch.cat(kc) + v = torch.cat(vc) + + return q, k, v + + +def maybe_unwrap_state_dict(lit_weights: Dict[str, 
torch.Tensor]) -> Dict[str, torch.Tensor]: + return lit_weights.get("model", lit_weights) + + +def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: + weight_names = {wk.split(".")[-1] for wk in lit_weights} + # LoRA or QLoRA + if any("lora" in wn for wn in weight_names): + raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") + # adapter v2. adapter_bias will only be in adapter_v2 + elif "adapter_bias" in weight_names: + raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") + # adapter. gating_factor is in adapter and adapter_v2 + elif "gating_factor" in weight_names: + raise NotImplementedError("Converting models finetuned with adapter not yet supported.") + + +@torch.inference_mode() +def convert_lit_checkpoint(*, checkpoint_name: str, out_dir: Path, model_name: str) -> None: + config = Config.from_name(model_name) + + if "falcon" in model_name: + copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") + elif config._mlp_class == "LLaMAMLP": + copy_fn = partial(copy_weights_llama, config) + else: + copy_fn = copy_weights_gpt_neox + + # initialize a new empty state dict to hold our new weights + sd = {} + + # checkpoint_name cannot be hardcoded because there exists different outputs such as + # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"") + pth_file = out_dir / checkpoint_name + bin_file = pth_file.with_suffix(".bin") + + with incremental_save(bin_file) as saver: + with contextlib.ExitStack() as stack: + lit_weights = stack.enter_context(lazy_load(pth_file)) + lit_weights = maybe_unwrap_state_dict(lit_weights) + check_conversion_supported(lit_weights) + # Incremental save will trigger error + copy_fn(sd, lit_weights, saver=None) + gc.collect() + saver.save(sd) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(convert_lit_checkpoint, as_positional=False) diff --git a/scripts/prepare_redpajama.py b/scripts/prepare_redpajama.py new file mode 100644 index 0000000..2b56726 --- /dev/null +++ b/scripts/prepare_redpajama.py @@ -0,0 +1,166 @@ +import glob +import json +import os +import sys +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Config, Tokenizer + +filenames_sample = [ + "arxiv_sample.jsonl", + "book_sample.jsonl", + "c4_sample.jsonl", + "cc_2019-30_sample.jsonl", + "cc_2020-05_sample.jsonl", + "cc_2021-04_sample.jsonl", + "cc_2022-05_sample.jsonl", + "cc_2023-06_sample.jsonl", + "github_sample.jsonl", + "stackexchange_sample.jsonl", + "wikipedia_sample.jsonl", +] + +filename_sets = { + "arxiv": "arxiv/arxiv*", + "book": "book/book*", + "c4": "c4/c4-train*", + "common_crawl": "common_crawl/*", + "github": "github/filtered*", + "stackexchange": "stackexchange/stackexchange*", + "wikipedia": "wikipedia/wiki*", +} + + +def prepare_sample( + source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for name in filenames_sample: + if match and match not in name: + continue + + filepath = source_path / name + + if not filepath.is_file(): + raise 
RuntimeError( + f"Input file not found at {filepath}. \nMake sure you download the data, e.g. wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + prefix, _ = os.path.splitext(name) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=prefix, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + print(f"Processing {name}") + + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare_full( + source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for set_name, pattern in filename_sets.items(): + if match and match not in set_name: + continue + + is_cc = set_name == "common_crawl" + + filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) + + if not filenames: + raise RuntimeError( + f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=set_name, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for name in filenames: + filepath = source_path / name + + print(f"Processing {name}") + + if is_cc: + with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + else: + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + destination_path: Path = Path("data/redpajama_sample"), + sample: bool = True, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset. 
We assume tokenizer has been trained.""" + with open(checkpoint_dir / "lit_config.json") as fp: + config = Config(**json.load(fp)) + + prepare_fn = prepare_sample if sample else prepare_full + prepare_fn( + source_path=source_path, + checkpoint_dir=checkpoint_dir, + destination_path=destination_path, + chunk_size=(config.block_size + 1) * 1024, # block size + 1 for causal, 1024 blocks + match=match, + ) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_slimpajama.py b/scripts/prepare_slimpajama.py new file mode 100644 index 0000000..24ec050 --- /dev/null +++ b/scripts/prepare_slimpajama.py @@ -0,0 +1,105 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +# Filename for SlimPajama +slimpajama_sets = { + "train": "train/chunk*/*", + "validation": "validation/chunk*/*", + "test": "test/chunk*/*", +} + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching {slimpajama_sets[split]} found at {source_path}. \n" + "Make sure you download the data..." 
+ ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=f"{split}_slimpajama_{process_id}", # Use process_id to differentiate builders + chunk_size=chunk_size, + sep_token=tokenizer.bos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for filepath in filenames: + print(f"Processing {filepath}") + with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub": + continue # we don't want to include the github data + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), + destination_path: Path = Path("data/red_pajama_sample"), + chunk_size: int = 2049 * 1024, + split: str="train", + percentage: float = 1.0, +) -> None: + import time + + filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True) + filenames = filenames[:int(len(filenames) * percentage)] + + num_processes = cpu_count() + chunked_filenames = np.array_split(filenames, num_processes) + + processes = [] + start_time = time.time() + + for i, subset in enumerate(chunked_filenames): + p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) + processes.append(p) + p.start() + + for p in processes: + p.join() + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + from jsonargparse import CLI + CLI(prepare) \ No newline at end of file diff --git a/scripts/prepare_starcoder.py b/scripts/prepare_starcoder.py new file mode 100644 index 0000000..838a29f --- /dev/null +++ b/scripts/prepare_starcoder.py @@ -0,0 +1,100 @@ +import json +import glob +import os +from pathlib import Path +import sys +from typing import List +import numpy as np +from tqdm import tqdm +from multiprocessing import Process, cpu_count + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Tokenizer + +import pandas as pd + + +def prepare_full( + source_path: Path, + tokenizer_path: Path, + destination_path: Path, + chunk_size: int, + split: str="train", + filenames_subset: List[str] = None, + process_id: int = 0 +) -> None: + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(tokenizer_path) + + # Use the provided filenames_subset or default to all filenames + filenames = filenames_subset + + if not filenames: + raise RuntimeError( + f"No files matching found at {source_path}. \n" + "Make sure you download the data..." 
+        )
+
+    builder = packed_dataset.PackedDatasetBuilder(
+        outdir=destination_path,
+        prefix=f"{split}_starcoder_{process_id}", # Use process_id to differentiate builders
+        chunk_size=chunk_size,
+        sep_token=tokenizer.bos_id,
+        dtype="auto",
+        vocab_size=tokenizer.vocab_size,
+    )
+
+    for filepath in filenames:
+        print(f"Processing {filepath}")
+        try:
+            contents = pd.read_parquet(filepath, engine='pyarrow')['content']
+        except Exception:
+            # skip unreadable parquet shards instead of aborting the whole worker
+            print(f"Error reading {filepath}!!")
+            continue
+        for text in contents:
+            text_ids = tokenizer.encode(text)
+            builder.add_array(np.array(text_ids, dtype=builder.dtype))
+
+    builder.write_reminder()
+
+
+def prepare(
+    source_path: Path = Path("data/RedPajama-Data-1T-Sample"),
+    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
+    destination_path: Path = Path("data/red_pajama_sample"),
+    chunk_size: int = 2049 * 1024,
+    split: str = "train",
+    percentage: float = 1.0,
+) -> None:
+    import time
+    assert split == "train" # starcoder only has train data
+    filenames = glob.glob(os.path.join(source_path, "*/*.parquet"), recursive=True)
+    filenames = filenames[:int(len(filenames) * percentage)]
+    num_processes = 32
+    chunked_filenames = np.array_split(filenames, num_processes)
+
+    processes = []
+    start_time = time.time()
+
+    for i, subset in enumerate(chunked_filenames):
+        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))
+        processes.append(p)
+        p.start()
+
+    for p in processes:
+        p.join()
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"Time taken: {elapsed_time:.2f} seconds")
+
+
+if __name__ == "__main__":
+    from jsonargparse import CLI
+    CLI(prepare)
\ No newline at end of file
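Both prepare_slimpajama.py and prepare_starcoder.py fan the tokenization work out the same way: the file list is split into roughly equal chunks with numpy, and one worker process is started per chunk, each writing its own packed-dataset prefix. A minimal sketch of that pattern, with a hypothetical process_chunk worker standing in for prepare_full:

import numpy as np
from multiprocessing import Process, cpu_count


def process_chunk(filenames, process_id):
    # stand-in for prepare_full: each worker tokenizes only its own subset of files
    print(f"worker {process_id} received {len(filenames)} files")


if __name__ == "__main__":
    filenames = [f"data/chunk_{i}.parquet" for i in range(100)]  # hypothetical paths
    chunks = np.array_split(filenames, cpu_count())
    workers = [Process(target=process_chunk, args=(list(chunk), i)) for i, chunk in enumerate(chunks)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()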