diff --git a/.github/workflows/preview_docs.yml b/.github/workflows/preview_docs.yml new file mode 100644 index 00000000..1dc10304 --- /dev/null +++ b/.github/workflows/preview_docs.yml @@ -0,0 +1,19 @@ +# Add a link to preview the documentation on Read the Docs for every pull request. +name: "RTD preview" + +on: + pull_request_target: + types: + - opened + +permissions: + pull-requests: write + +jobs: + documentation-links: + runs-on: ubuntu-latest + steps: + - uses: readthedocs/actions/preview@v1 + with: + project-slug: "readthedocs-preview" + single-version: true \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..b86618e3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,59 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Sphinx documentation +docs/_build/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Environments +.venv \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..ef6cd0f1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-ast + - id: check-merge-conflict + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/mwouts/jupytext + rev: v1.15.2 + hooks: + - id: jupytext + files: docs/tutorials/ + args: [--sync] \ No newline at end of file diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..d6220feb --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,15 @@ +version: 2 + +build: + os: ubuntu-lts-latest + tools: + python: "3.12" + +sphinx: + configuration: docs/conf.py + # Note this is set to false for now while the warnings are resolved + fail_on_warning: false + +python: + install: + - requirements: docs/requirements.txt \ No newline at end of file diff --git a/README.md b/README.md index e11f5f39..46e06479 100644 --- a/README.md +++ b/README.md @@ -9,4 +9,3 @@ open source, fast and deterministic. Check out [`tutorials/`](./tutorials) for more information on how to use Grain! - diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md new file mode 100644 index 00000000..1c65a7a4 --- /dev/null +++ b/docs/CONTRIBUTING.md @@ -0,0 +1,106 @@ +# Contributing to Grain + + + +## Contributing to the Grain project documentation + +### Pre-requisites + +To contribute to the documentation, you will need to set your development environment. + +You can create a virtual environment or conda environment and install the packages in +`docs/requirements.txt`. 
+ +```bash +# Create a virtual environment +python -m venv .venv +# Activate the virtual environment +source .venv/bin/activate +# Install the requirements +pip install -r docs/requirements.txt +``` + +or with conda + +```bash +# Create a conda environment +conda create -n "grain-docs" python=3.12 +# Activate the conda environment +conda activate grain-docs +# Install the requirements +python -m pip install -r docs/requirements.txt +``` + +### Building the documentation locally + +To build the documentation locally, you can run the following command: + +```bash +# Change to the docs/ directory +cd docs +sphinx-build -b html . _build/html +``` + +You can then open the generated HTML files in your browser by opening +`docs/_build/html/index.html`. + +## Documentation via Jupyter notebooks + +The `pygrain` documentation includes Jupyter notebooks that are rendered +directly into the website via the [myst-nb](https://myst-nb.readthedocs.io/) extension. +To ease review and diff of notebooks, we keep markdown versions of the content +synced via [jupytext](https://jupytext.readthedocs.io/). + +Note you will need to install `jupytext` to sync the notebooks with markdown files: + +```bash +# With pip +python -m pip install jupytext + +# With conda +conda install -c conda-forge jupytext +``` + +### Adding a new notebook + +We aim to have one notebook per topic or tutorial covered. +To add a new notebook to the repository, first move the notebook into the appropriate +location in the `docs` directory: + +```bash +mv ~/new-tutorial.ipynb docs/tutorials/new_tutorial.ipynb +``` + +Next, we use `jupytext` to mark the notebook for syncing with Markdown: + +```bash +jupytext --set-formats ipynb,md:myst docs/tutorials/new_tutorial.ipynb +``` + +Finally, we can sync the notebook and markdown source: + +```bash +jupytext --sync docs/tutorials/new_tutorial.ipynb +``` + +To ensure that the new notebook is rendered as part of the site, be sure to add +references to a `toctree` declaration somewhere in the source tree, for example +in `docs/index.md`. You will also need to add references in `docs/conf.py` +to specify whether the notebook should be executed, and to specify which file +sphinx should use when generating the site. + +### Editing an existing notebook + +When editing the text of an existing notebook, it is recommended to edit the +markdown file only, and then automatically sync using `jupytext` via the +`pre-commit` framework, which we use to check in GitHub CI that notebooks are +properly synced. +For example, say you have edited `docs/tutorials/new_tutorial.md`, then +you can do the following: + +```bash +pip install pre-commit +git add docs/tutorials/new_tutorial.* # stage the new changes +pre-commit run # run pre-commit checks on added files +git add docs/tutorials/new_tutorial.* # stage the files updated by pre-commit +git commit -m "Update new tutorial" # commit to the branch \ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 3067b191..b494549b 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,10 +2,6 @@ -https://github.com/google/grain/tree/main/docs - - - PyGrain is the pure Python backend for Grain, primarily targeted at JAX users. PyGrain is designed to be: @@ -26,19 +22,19 @@ of dependencies when possible. For example, it should not depend on TensorFlow. ## High Level Idea -The PyGrain backend differs from traditional tf.data pipelines. Instead of +The PyGrain backend differs from traditional `tf.data` pipelines. 
Instead of starting from filenames that need to be shuffled and interleaved to shuffle the data, PyGrain pipeline starts by sampling indices. Indices are globally unique, monotonically increasing values used to track progress of the pipeline (for checkpointing). These indices are then mapped into -record keys in the range [0, len(dataset)]. Doing so enables *global +record keys in the range `[0, len(dataset)]`. Doing so enables *global transformations* to be performed (e.g. global shuffling, mixing, repeating for multiple epochs, sharding across multiple machines) before reading any records. *Local transformations* that map/filter (aka preprocessing) a single example or combine multiple consecutive records happen after reading. -![Difference between typical tf.data pipeline and a PyGrain pipeline](grain_pipeline.svg) +![Difference between typical tf.data pipeline and a PyGrain pipeline](./images/grain_pipeline.svg) Steps in the pipeline: @@ -55,7 +51,7 @@ Steps in the pipeline: ## Training Loop -*PyGrain* has no opinion on how you write your training loop. Instead PyGrain +*PyGrain* has no opinion on how you write your training loop. Instead, PyGrain will return an iterator that implements: * `next(ds_iter)` returns the element as NumPy arrays. @@ -99,4 +95,3 @@ order defined by the user. The first of these transformations needs to be able to process the raw records as read by the data source. The second transformation needs to be able to process the elements produced by the first transformation and so on. - diff --git a/docs/behind_the_scenes.md b/docs/behind_the_scenes.md index 9002b556..76b4a7d4 100644 --- a/docs/behind_the_scenes.md +++ b/docs/behind_the_scenes.md @@ -1,9 +1,5 @@ # Behind the Scenes - - -https://github.com/google/grain/blob/main/docs/behind_the_scenes.md - In this section, we explore the design of PyGrain. The following diagram illustrates the data flow within PyGrain. Parent process (where user creates the `DataLoader` object) is highlighted in blue while the child processes are @@ -13,36 +9,37 @@ highlighted in green. ![Diagram showing PyGrain data flow in the generic case of multiple workers.](images/data_flow_multiple_workers.png "PyGrain DataFlow, multiple workers.") -A. Parent process launches the Feeder thread. The Feeder thread iterates through -the sampler and distributes `RecordMetadata` objects to input queues of the -child processes (each child process has its own dedicated queue). Parent process -also launches `num_workers` child processes. +* A. Parent process launches the Feeder thread. The Feeder thread iterates + through the sampler and distributes `RecordMetadata` objects to input queues + of the child processes (each child process has its own dedicated queue). + Parent process also launches `num_workers` child processes. -B. Each child process reads `RecordMetadata` objects from its respective -input queue. +* B. Each child process reads `RecordMetadata` objects from its respective + input queue. -C. Each child process reads the data record (corresponding to `record_keys` from -the `RecordMetadata` objects) using the data source. +* C. Each child process reads the data record (corresponding to `record_keys` + from the `RecordMetadata` objects) using the data source. -D. Each child process applies the user-provided Operations to the records -it reads. +* D. Each child process applies the user-provided Operations to the records it + reads. -E. 
Each child process sends the resulting batches via its dedicated output queue -(offloading sending NumPy Arrays/ Tensors to shared memory blocks.) +* E. Each child process sends the resulting batches via its dedicated output + queue (offloading sending NumPy Arrays/ Tensors to shared memory blocks.) -F. The reader thread in the parent process gets the output batches from the -output queues of the child processes (going through the child processes output -queues in a round-robin fashion.) +* F. The reader thread in the parent process gets the output batches from the + output queues of the child processes (going through the child processes + output queues in a round-robin fashion.) -G. The reader thread submits the batches it read to the reader thread pool to -asynchronously post process them (copy data out of shared memory blocks, close -and unlink shared memory blocks.) An [`AsyncResult`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.AsyncResult) -for the computation happening in the reader thread pool is put into the reader -queue (to ensure results ordering.) +* G. The reader thread submits the batches it read to the reader thread pool + to asynchronously post process them (copy data out of shared memory blocks, + close and unlink shared memory blocks.) An + [`AsyncResult`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.AsyncResult) + for the computation happening in the reader thread pool is put into the + reader queue (to ensure results ordering.) -H. When the end user requests the next element, the main thread gets the -`AsyncResult` from the reader queue and waits for the result to be ready. It -then provides the result to the end user. +* H. When the end user requests the next element, the main thread gets the + `AsyncResult` from the reader queue and waits for the result to be ready. It + then provides the result to the end user. Note that the diagram above illustrates the case when `num_workers` is greater than 0. When `num_workers` is 0, there are no child processes launched. Thus the @@ -50,8 +47,8 @@ flow of data becomes as follows: ![Diagram showing PyGrain dataflow when number of workers is 0](images/data_flow_zero_workers.png "PyGrain DataFlow, zero workers.") - ## The Need for Multiprocessing + In CPython, the global interpreter lock, or GIL, is a mutex that protects access to Python objects, preventing multiple threads from executing Python bytecode at once. This is necessary mainly because CPython's memory management is not @@ -64,6 +61,7 @@ the machine. Multiprocessing solves this problem as different processes are able to run on different CPU cores. ## Communication between processes + Each child process has its own memory space and thus by default processes can’t access each other’s objects. Typically, the communication between processes occurs via [multiprocessing queues](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue). @@ -73,6 +71,7 @@ of elements to be buffered and keeping the synchronisation between producer/ consumer in case the queue is full/empty. ### Shared memory + Queues involve serialising elements (via Pickle), sending elements over a connection and deserialising elements at the receiver side. When the data elements are big (e.g. a batch of videos), communication becomes a bottleneck. @@ -118,6 +117,7 @@ memory block, the block is then closed and unlinked. Shared memory opening, closing, and unlinking are all time-consuming, and we make them run asynchronously. 
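
To make the hand-off concrete, the following is a minimal, self-contained sketch of the pattern described above. It is not Grain's actual implementation: a worker process copies a NumPy batch into a named shared memory block and sends only the block's name and array metadata over a queue; the parent attaches by name, copies the data out, and then closes and unlinks the block.

```python
import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np


def worker(queue):
  batch = np.arange(12, dtype=np.int64).reshape(3, 4)  # Stand-in for a processed batch.
  shm = shared_memory.SharedMemory(create=True, size=batch.nbytes)
  # Copy the batch into the shared memory block.
  np.ndarray(batch.shape, dtype=batch.dtype, buffer=shm.buf)[...] = batch
  # Only cheap metadata crosses the queue, not the (potentially large) payload.
  queue.put((shm.name, batch.shape, batch.dtype.str))
  shm.close()


if __name__ == "__main__":
  queue = mp.Queue()
  child = mp.Process(target=worker, args=(queue,))
  child.start()
  name, shape, dtype = queue.get()
  shm = shared_memory.SharedMemory(name=name)
  # Copy out of shared memory so the block can be released right away.
  batch = np.ndarray(shape, dtype=np.dtype(dtype), buffer=shm.buf).copy()
  shm.close()
  shm.unlink()
  child.join()
  print(batch)
```

In the real pipeline this hand-off happens per output batch and the close/unlink step is deferred to the reader thread pool, which is why making it asynchronous matters.
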
## Determinism + One of the core requirements of PyGrain is determinism. Determinism involves producing the same elements in the same ordering across multiple runs of the pipeline. @@ -140,18 +140,16 @@ As an example, suppose we have a dataset with 8 records, and we apply shuffling to it. The Sampler might produce something like the following (we omit `index` and `rng` for brevity and show only the record keys): -record keys: [5, 2, 0, 4, 6, 1, 7, 3] - -Having 2 process [P0, P1], each will get the following records keys: - -* P0 gets records keys: [5, 0, 6, 7] -* P1 gets records keys: [2, 4, 1, 3] +record keys: `[5, 2, 0, 4, 6, 1, 7, 3]` -P0 and P1 read records with their assigned record keys, apply transformations to -them, and add them to their respective output queues. The parent process goes -through the output queues for the child processes P0 and P1 in a strict -round-robin fashion. Assuming we apply a map operation + a batch operation -(batch_size = 2), the final ordering of the elements would be: +Having 2 processes `[P0, P1]`, each will get the following records keys: -[[5, 0], [2, 4], [6,7 ], [1, 3]] +* `P0` gets records keys: `[5, 0, 6, 7]` +* `P1` gets records keys: `[2, 4, 1, 3]` +`P0` and `P1` read records with their assigned record keys, apply +transformations to them, and add them to their respective output queues. The +parent process goes through the output queues for the child processes `P0` and +`P1` in a strict round-robin fashion. Assuming we apply a map operation + a +batch operation (`batch_size = 2`), the final ordering of the elements would be +`[[5, 0], [2, 4], [6,7 ], [1, 3]]`. diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 00000000..c3e8104d --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,98 @@ +"""Configuration file for the Sphinx documentation builder. + +For the full list of built-in configuration values, see the documentation: +https://www.sphinx-doc.org/en/master/usage/configuration.html +""" + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. 
+import pathlib +import sys + +sys.path.insert(0, str(pathlib.Path('..', 'grain').resolve())) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'Grain' +copyright = '2024, Grain team' # pylint: disable=redefined-builtin +author = 'Grain team' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + 'myst_nb', + 'sphinx_copybutton', + 'sphinx_design', + 'autoapi.extension', +] + +templates_path = ['_templates'] +source_suffix = ['.rst', '.ipynb', '.md'] +exclude_patterns = [ + '_build', + 'Thumbs.db', + '.DS_Store', + 'tutorials/dataset_basic_tutorial.md', +] + +# Suppress warning in exception basic_data_tutorial +suppress_warnings = [ + 'misc.highlighting_failure', +] + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'sphinx_book_theme' +html_title = 'PyGrain' +html_static_path = ['_static'] + +# TODO: Add logo and favicon +# html_logo = '_static/' +# html_favicon = '_static/favicon.png' + +# Theme-specific options +# https://sphinx-book-theme.readthedocs.io/en/stable/reference.html +html_theme_options = { + 'show_navbar_depth': 1, + 'show_toc_level': 3, + 'repository_url': 'https://github.com/google/grain', + 'use_issues_button': True, + 'use_repository_button': True, + 'path_to_docs': 'docs/', + 'navigation_with_keys': True, +} + +# Autodoc settings +# Should be relative to the source of the documentation +autoapi_dirs = [ + '../grain/_src/core', + '../grain/_src/python', +] + +autoapi_ignore = [ + '*_test.py', + 'testdata/*', + '*/dataset/stats.py', +] + +# -- Myst configurations ------------------------------------------------- +myst_enable_extensions = ['colon_fence'] +nb_execution_mode = 'force' +nb_execution_allow_errors = False +nb_merge_streams = True +nb_execution_show_tb = True + +# Notebook cell execution timeout; defaults to 30. +nb_execution_timeout = 100 + +# List of patterns, relative to source directory, that match notebook +# files that will not be executed. +nb_execution_excludepatterns = [ + 'tutorials/dataset_basic_tutorial.ipynb', +] diff --git a/docs/data_loader/samplers.md b/docs/data_loader/samplers.md index 5e7508f4..ad74383f 100644 --- a/docs/data_loader/samplers.md +++ b/docs/data_loader/samplers.md @@ -1,9 +1,5 @@ # Samplers - - -https://github.com/google/grain/blob/main/docs/data_loader/samplers.md - Samplers in PyGrain are responsible for determining the order in which records are processed. This allows PyGrain to implement global transformations (e.g. global shuffling, sharding, repeating for multiple epochs) before reading any @@ -38,16 +34,18 @@ class RecordMetadata: ``` ## Index Sampler + This is our recommended Sampler. It supports: -* Sharding across multiple machines (`shard_options` parameter). -* Global shuffle of the data (`shuffle` parameter). -* Repeating records for multiple epochs (`num_epochs` parameter). Note that the -shuffle order changes across epochs. Behind the scenes, this relies on -[tf.random_index_shuffle](https://www.tensorflow.org/api_docs/python/tf/random_index_shuffle). -* Stateless random operations. Each `RecordMetadata` object emitted by the -`IndexSampler` contains an RNG uniquely seeded on a per-record basis. 
This -RNG can be used for random augmentations while not relying on a global state. +* Sharding across multiple machines (`shard_options` parameter). +* Global shuffle of the data (`shuffle` parameter). +* Repeating records for multiple epochs (`num_epochs` parameter). Note that + the shuffle order changes across epochs. Behind the scenes, this relies on + [tf.random_index_shuffle](https://www.tensorflow.org/api_docs/python/tf/random_index_shuffle). +* Stateless random operations. Each `RecordMetadata` object emitted by the + `IndexSampler` contains an RNG uniquely seeded on a per-record basis. This + RNG can be used for random augmentations while not relying on a global + state. ```python index_sampler = pygrain.IndexSampler( @@ -76,9 +74,9 @@ for record_metadata in index_sampler: PyGrain can accommodate custom user-defined samplers. Users implementing their own sampler should ensure it: -* implements the aforementioned interface. -* is adequately performant. Since PyGrain's -`DataLoader` iterates sequentially through the sampler to distribute indices to -child processes, a slow sampler will become a bottleneck and reduce end-to-end -pipeline performance. As a reference, we recommend sampler iteration performance -of at approx. 50,000 elements / sec for most use cases. +* implements the aforementioned interface. +* is adequately performant. Since PyGrain's `DataLoader` iterates sequentially + through the sampler to distribute indices to child processes, a slow sampler + will become a bottleneck and reduce end-to-end pipeline performance. As a + reference, we recommend sampler iteration performance of at approx. 50,000 + elements / sec for most use cases. diff --git a/docs/data_loader/transformations.md b/docs/data_loader/transformations.md index f30133ff..244baafd 100644 --- a/docs/data_loader/transformations.md +++ b/docs/data_loader/transformations.md @@ -2,10 +2,6 @@ -https://github.com/google/grain/blob/main/docs/data_loader/transformations.md - - - Grain Transforms interface denotes transformations which are applied to data. In the case of local transformations (such as map, random map, filter), the transforms receive an element on which custom changes are applied. For global @@ -17,11 +13,11 @@ The Grain core transforms interface code is ## MapTransform -MapTransform is for 1:1 transformations of elements. Elements can be of any +`MapTransform` is for 1:1 transformations of elements. Elements can be of any type, it is the user's responsibility to use the transformation such that the inputs it receives correspond to the signature. -Example of transformation which implements MapTransform (for elements of type +Example of transformation which implements `MapTransform` (for elements of type `int`): ```python @@ -33,10 +29,11 @@ class PlusOne(transforms.MapTransform): ## RandomMapTransform -RandomMapTransform is for 1:1 random transformations of elements. The interface -requires a `np.random.Generator` as parameter to the 'random_map' function. +`RandomMapTransform` is for 1:1 random transformations of elements. The +interface requires a `np.random.Generator` as parameter to the `random_map` +function. -Example of a RandomMapTransform: +Example of a `RandomMapTransform`: ```python class PlusRandom(transforms.RandomMapTransform): @@ -47,11 +44,11 @@ class PlusRandom(transforms.RandomMapTransform): ## FlatMapTransform -FlatMapTransform is for splitting operations of individual elements. The +`FlatMapTransform` is for splitting operations of individual elements. 
The `max_fan_out` is the maximum number of splits that an element can generate. Please consult the code for detailed info. -Example of a FlatMapTransform: +Example of a `FlatMapTransform`: ```python class FlatMapTransformExample(transforms.FlatMapTransform): @@ -64,10 +61,10 @@ class FlatMapTransformExample(transforms.FlatMapTransform): ## FilterTransform -FilterTransform is for applying filtering to individual elements. Elements for +`FilterTransform` is for applying filtering to individual elements. Elements for which the filter function returns False will be removed. -Example of a FilterTransform that removes all even elements: +Example of a `FilterTransform` that removes all even elements: ```python class RemoveEvenElements(FilterTransform): @@ -78,8 +75,9 @@ class RemoveEvenElements(FilterTransform): ## Batch -To apply the Batch transform, just pass `grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)`. +To apply the `Batch` transform, pass `grain.Batch(batch_size=batch_size, +drop_remainder=drop_remainder)`. -Note: The batch size used when passing Batch transform will be the global batch -size if it is done before sharding and the *per host* batch size if it is after. -Typically usage with IndexSampler is after sharding. \ No newline at end of file +Note: The batch size used when passing `Batch` transform will be the global +batch size if it is done before sharding and the *per host* batch size if it is +after. Typically usage with `IndexSampler` is after sharding. diff --git a/docs/data_sources.md b/docs/data_sources.md index 85e516a2..ba613f0c 100644 --- a/docs/data_sources.md +++ b/docs/data_sources.md @@ -2,10 +2,6 @@ -https://github.com/google/grain/blob/main/docs/data_sources.md - - - A PyGrain data source is responsible for retrieving individual records. Records could be in a file/storage system or generated on the fly. Data sources need to implement the following protocol: @@ -24,8 +20,7 @@ class RandomAccessDataSource(Protocol, Generic[T]): ## File Format Note that the underlying file format/storage system needs to support efficient -random access. -Grain currently supports random-access file format [ArrayRecord](https://github.com/google/array_record) +random access. Grain currently supports random-access file format [ArrayRecord](https://github.com/google/array_record). ## Available Data Sources @@ -33,10 +28,10 @@ We provide a variety of data sources for PyGrain, which we discuss in the follow ### Range Data Source -This data source mimics the built in python +This data source mimics the built-in Python [range class](https://docs.python.org/3/library/functions.html#func-range). It can be used for initial PyGrain testing or if your use case involves generating -records on the fly (for example if you just want to generate synthetic records +records on the fly (for example if you only want to generate synthetic records online rather than read records from storage.) ```python @@ -77,13 +72,12 @@ File instruction objects enable a few use cases: ### TFDS Data Source TFDS provides PyGrain compatible data sources via `tfds.data_source()`. -Arguments are equivalent to `tfds.load()`. +Arguments are equivalent to `tfds.load()`. For more information see ```python tfds_data_source = tfds.data_source("imagenet2012", split="train[:75%]") ``` -TIP: Make sure to depend on `//tensorflow_datasets:tf_less`. The ## Implement your own Data Source You can implement your own data source and use it with PyGrain. 
It needs to diff --git a/docs/grain_pipeline.svg b/docs/images/grain_pipeline.svg similarity index 100% rename from docs/grain_pipeline.svg rename to docs/images/grain_pipeline.svg diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..771be6a8 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,69 @@ +# Grain - Feeding JAX Models + + + +Grain is a library for reading data for training and evaluating JAX models. It's +open source, fast and deterministic. + +::::{grid} 1 2 2 3 :gutter: 1 1 1 2 + +:::{grid-item-card} {octicon}`zap;1.5em;sd-mr-1` Powerful Users can bring +arbitrary Python transformations. ::: + +:::{grid-item-card} {octicon}`tools;1.5em;sd-mr-1` Flexible Grain is designed to +be modular. Users can readily override Grain components if need be with their +own implementation. ::: + +:::{grid-item-card} {octicon}`versions;1.5em;sd-mr-1` Deterministic Multiple +runs of the same pipeline will produce the same output. ::: + +:::{grid-item-card} {octicon}`check-circle;1.5em;sd-mr-1` Resilient to +pre-emptions Grain is designed such that checkpoints have minimal size. After +pre-emption, Grain can resume from where it left off and produce the same output +as if it was never pre-empted. ::: + +:::{grid-item-card} {octicon}`sparkles-fill;1.5em;sd-mr-1` Performant We took +care while designing Grain to ensure that it's performant (refer to the +[Behind the Scenes](behind_the_scenes.md) section of the documentation.) We also +tested it against multiple data modalities (e.g.Text/Audio/Images/Videos). ::: + +:::{grid-item-card} {octicon}`package;1.5em;sd-mr-1` With minimal dependencies +Grain minimizes its set of dependencies when possible. For example, it should +not depend on TensorFlow. ::: + +:::: + +``` {toctree} +:maxdepth: 1 +:caption: Getting started +installation +behind_the_scenes +data_sources +``` + +``` {toctree} +:maxdepth: 1 +:caption: Data Loader +data_loader/samplers +data_loader/transformations +``` + +``` {toctree} +:maxdepth: 1 +:caption: Tutorials +tutorials/dataset_basic_tutorial +``` + +``` {toctree} +:maxdepth: 1 +:caption: Contributor guides +CONTRIBUTING +``` + + + +``` {toctree} +:maxdepth: 1 +:caption: References +autoapi/index +``` diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 00000000..f777fd07 --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,9 @@ +# Installing Grain + + + +To install Grain, you can use pip: + +```bash +pip install grain +``` diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..6863b502 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,13 @@ +# Sphinx-related requirements. +sphinx +sphinx-book-theme>=1.0.1 +myst-nb +myst-parser[linkify] +sphinx-book-theme +sphinx-copybutton +sphinx-design +# Avoiding an issue with the collapsible sidebar. +pydata-sphinx-theme<0.16.0 +# To generate API documentation. 
+sphinx-autoapi +sphinx-autodoc2 \ No newline at end of file diff --git a/tutorials/dataset_basic_tutorial.ipynb b/docs/tutorials/dataset_basic_tutorial.ipynb similarity index 100% rename from tutorials/dataset_basic_tutorial.ipynb rename to docs/tutorials/dataset_basic_tutorial.ipynb diff --git a/docs/tutorials/dataset_basic_tutorial.md b/docs/tutorials/dataset_basic_tutorial.md new file mode 100644 index 00000000..3a02162b --- /dev/null +++ b/docs/tutorials/dataset_basic_tutorial.md @@ -0,0 +1,245 @@ +-------------------------------------------------------------------------------- + +jupytext: formats: ipynb,md:myst text_representation: extension: .md +format_name: myst format_version: 0.13 jupytext_version: 1.15.2 kernelspec: +display_name: Python 3 + +## name: python3 + ++++ {"id": "BvnXLPI_2dNJ"} + +# Dataset Basic Tutorial with PyGrain + +Installs PyGrain (OSS only) + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: sHOibn5Q2GRt +outputId: f4c3e5a6-56b8-47f1-c5a1-a25fd0c433b3 +--- +# @test {"output": "ignore"} +!pip install grain +``` + ++++ {"id": "8UuJxi2p3lPp"} + +# Imports + +``` {code-cell} +:id: ZgB5xOru2Zz8 +import grain.python as grain +import pprint +``` + ++++ {"id": "gPv3wrQd3pZS"} + +# `MapDataset` + +`MapDataset` defines a dataset that supports efficient random access. Think of +it as an (infinite) `Sequence` that computes values lazily. It will either be +the starting point of the input pipeline or in the middle of the pipeline +following another `MapDataset`. Grain provides many basic transformations for +users to get started. + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: 3z3Em5jC2iVz +outputId: b3350dec-a6a9-444b-95f8-dc5b6899f82c +--- +dataset = ( + grain.MapDataset.range(10) + .shuffle(seed=10) # Shuffles globally. + .map(lambda x: x+1) # Maps each element. + .batch(batch_size=2) # Batches consecutive elements. +) +pprint.pprint(dataset[0]) +pprint.pprint(list(dataset)) +``` + ++++ {"id": "Aii_JDBw5SEI"} + +The requirement for `MapDataset`'s source is a `grain.RandomAccessDataSource` +interface: i.e. `__getitem__` and `__len__`. + +``` {code-cell} +:id: kCbDSzlS4a-A +# Note: Inheriting `grain.RandomAccessDataSource` is optional but recommended. +class MySource(grain.RandomAccessDataSource): + def __init__(self): + self._data = [0, 1, 2, 3, 4, 5, 6, 7] + def __getitem__(self, idx): + return self._data[idx] + def __len__(self): + return len(self._data) +``` + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: m8Cyn7gt6FYy +outputId: f0ada3bd-5c38-4120-d9d4-e832a76cc3c6 +--- +source = MySource() +dataset = ( + grain.MapDataset.source(source) + .shuffle(seed=10) # Shuffles globally. + .map(lambda x: x+1) # Maps each element. + .batch(batch_size=2) # Batches consecutive elements. +) +pprint.pprint(dataset[0]) +pprint.pprint(list(dataset)) +``` + ++++ {"id": "zKv2kWjB6XPd"} + +Access by index will never raise an `IndexError` and can treat indices that are +equal or larger than the length as a different epoch (e.g. shuffle differently, +use different random numbers). + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: GSW1cJe06NEO +outputId: 547a8993-f835-4bae-bbad-3672666600e4 +--- +# Prints the 3rd element of the second epoch. +pprint.pprint(dataset[len(dataset)+2]) +``` + ++++ {"id": "azfAr8F37njE"} + +Note that `dataset[idx] == dataset[len(dataset) + idx]` iff there's no random +transfomations. 
Since `dataset` has global shuffle, different epochs are +shuffled differently: + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: _o3wxb8k7XDY +outputId: de5d0c1a-53a7-445c-913a-779d55cb85fe +--- +pprint.pprint(dataset[len(dataset)+2] == dataset[2]) +``` + ++++ {"id": "B2kLX0fa8GfV"} + +You can use `filter` to remove elements not needed but it will return `None` to +indicate that there is no element at the given index. + +Returning `None` for the majority of positions can negatively impact performance +of the pipeline. For example, if your pipeline filters 90% of the data it might +be better to store a filtered version of your dataset. + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: ai4zcltV7sSN +outputId: 5beff138-e194-414a-d219-0e90571828f7 +--- +filtered_dataset = dataset.filter(lambda e: (e[0] + e[1]) % 2 == 0) +pprint.pprint(f"Length of this dataset: {len(filtered_dataset)}") +pprint.pprint([filtered_dataset[i] for i in range(len(filtered_dataset))]) +``` + ++++ {"id": "FJLK_BQj9GuG"} + +`MapDataset` also supports slicing using the same syntax as Python lists. This +returns a `MapDataset` representing the sliced section. Slicing is the easiest +way to "shard" data during distributed training. + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: -fuS_OGS8x5Z +outputId: 7b18ef71-98bc-49ce-d5d6-07af4fdc89f3 +--- +shard_index = 0 +shard_count = 2 +sharded_dataset = dataset[shard_index::shard_count] +print(f"Sharded dataset length = {len(sharded_dataset)}") +pprint.pprint(sharded_dataset[0]) +pprint.pprint(sharded_dataset[1]) +``` + ++++ {"id": "KvycxocM-Fpk"} + +For the actual running training with the dataset, we should convert `MapDataset` +into `IterDataset` to leverage parallel prefetching to hide the latency of each +element's IO using Python threads. + +This brings us to the next section of the tutorial: `IterDataset`. + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: FnWPIpce9aAJ +outputId: 0441c0e8-371f-425a-9660-449eca19eece +--- +iter_dataset = sharded_dataset.to_iter_dataset(grain.ReadOptions(num_threads=16, prefetch_buffer_size=500)) +for element in iter_dataset: + pprint.pprint(element) +``` + ++++ {"id": "W-Brm4Mh_Bo1"} + +# IterDataset + +Most data pipelines will start with one or more `MapDataset` (often derived from +a `RandomAccessDataSource`) and switch to `IterDataset` late or not at all. +`IterDataset` does not support efficient random access and only supports +iterating over it. It's an `Iterable`. + +Any `MapDataset` can be turned into a `IterDataset` by calling +`to_iter_dataset`. When possible this should happen late in the pipeline since +it will restrict the transformations that can come after it (e.g. global shuffle +must come before). This conversion by default skips `None` elements. + ++++ {"id": "GDO1u2tQ_zPz"} + +`DatasetIterator` is a stateful iterator of `IterDataset`. The state of the +iterator can be cheaply saved and restored. This is intended for checkpointing +the input pipeline together with the trained model. The returned state will not +contain data that flows through the pipeline. + +Essentially, `DatasetIterator` only checkpoints index information for it to +recover (assuming the underlying content of files will not change). 
+ +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: DRgatGFX_nxL +outputId: 70043ad5-551c-44a0-adba-cb1927f38f6b +--- +dataset_iter = iter(dataset) +pprint.pprint(isinstance(dataset_iter, grain.DatasetIterator)) +``` + +``` {code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: dOCiJfSJ_vi4 +outputId: d71ad80a-2e4d-4367-fe89-c60b8c4b0039 +--- +pprint.pprint(next(dataset_iter)) +checkpoint = dataset_iter.get_state() +pprint.pprint(next(dataset_iter)) +# Recover the iterator to the state after the first produced element. +dataset_iter.set_state(checkpoint) +pprint.pprint(next(dataset_iter)) # This should generate the same element as above +``` + +``` {code-cell} +:id: Fh5iAUPqYQ7g +```
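
As an illustration of checkpointing in practice, the sketch below persists the iterator state to disk and restores it into a fresh iterator. This is a minimal sketch, not Grain's recommended checkpointing integration; it assumes the object returned by `get_state()` is picklable, and the file path is arbitrary.

```python
import pickle

state_path = "/tmp/grain_iterator_state.pkl"  # Arbitrary location for this sketch.

# Save the current position of the pipeline (index information only, no data).
with open(state_path, "wb") as f:
  pickle.dump(dataset_iter.get_state(), f)

# Later (e.g. after a pre-emption): rebuild the iterator and restore its state.
restored_iter = iter(dataset)
with open(state_path, "rb") as f:
  restored_iter.set_state(pickle.load(f))

# `restored_iter` now resumes from the saved position.
pprint.pprint(next(restored_iter))
```

In a real training setup you would store this state alongside the model checkpoint so that both can be restored together.
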