
Commit

Merge branch 'sok_examples' into 'main'
Sok documentation

See merge request dl/hugectr/hugectr!1536
minseokl committed Jun 14, 2024
2 parents abd8448 + 57edc38 commit 7ec7bee
Showing 107 changed files with 4,987 additions and 8,646 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -3,6 +3,7 @@
[![Version](https://img.shields.io/github/v/release/NVIDIA-Merlin/HugeCTR?color=orange)](release_notes.md/)
[![LICENSE](https://img.shields.io/github/license/NVIDIA-Merlin/HugeCTR)](https://github.com/NVIDIA-Merlin/HugeCTR/blob/main/LICENSE)
[![Documentation](https://img.shields.io/badge/documentation-blue.svg)](https://nvidia-merlin.github.io/HugeCTR/main/hugectr_user_guide.html)
[![SOK Documentation](https://img.shields.io/badge/SOK%20Documentation-blue?logoColor=blue)](https://nvidia-merlin.github.io/HugeCTR/sparse_operation_kit/master/index.html)

HugeCTR is a GPU-accelerated recommender framework designed for training and inference of large deep learning models.

479 changes: 479 additions & 0 deletions hps_tf/notebooks/sok_train_demo.ipynb

Large diffs are not rendered by default.

26 changes: 21 additions & 5 deletions sparse_operation_kit/ReadMe.md
@@ -9,15 +9,31 @@ In sparse training / inference scenarios, for instance, CTR estimation, there ar

SOK provides broad MP functionality to fully utilize all available GPUs, regardless of whether these GPUs are located in a single machine or multiple machines. Simultaneously, SOK takes advantage of existing data-parallel (DP) capabilities of DL frameworks to accelerate training while minimizing code changes. With SOK embedding layers, you can build a DNN model with mixed MP and DP. MP is used to shard large embedding parameter tables, such that they are distributed among the available GPUs to balance the workload, while DP is used for layers that consume only a small amount of GPU resources.

SOK provides multiple types of MP embedding layers, optimized for different application scenarios. These embedding layers can leverage all available GPU memory in your cluster to store/retrieve embedding parameters. As a result, all utilized GPUs work synchronously.

SOK is compatible with DP training provided by common synchronized training frameworks, such as [Horovod](https://horovod.ai). Because the input data fed to these embedding layers can take advantage of DP, additional DP-to-MP and MP-to-DP transformations are needed when SOK is used to scale up your DNN model from a single GPU to multiple GPUs. The following picture illustrates the workflow of these embedding layers.
![WorkFlowOfEmbeddingLayer](documents/source/images/workflow_of_embeddinglayer.png)
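
The snippet below is a minimal sketch of this mixed DP/MP workflow with Horovod; it follows the `sok.Variable`/`sok.lookup_sparse` usage from the SOK getting-started guide, and the table shapes and index values are illustrative placeholders (exact signatures may differ between SOK versions).
```python
import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
import sparse_operation_kit as sok

# One process per GPU: Horovod provides the DP part, SOK shards the tables (MP).
hvd.init()
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
sok.init()

# Two embedding tables, sharded across the available GPUs by SOK.
table0 = sok.Variable(np.random.rand(10000, 16).astype(np.float32))
table1 = sok.Variable(np.random.rand(5000, 8).astype(np.float32))

# Data-parallel inputs: each rank feeds its own batch of ragged indices.
indices0 = tf.ragged.constant([[0, 1, 2], [3, 4]], dtype=tf.int64)
indices1 = tf.ragged.constant([[5], [6, 7]], dtype=tf.int64)

# SOK performs the DP <-> MP transformations internally during the lookup.
emb0, emb1 = sok.lookup_sparse([table0, table1], [indices0, indices1], combiners=["sum", "mean"])
print(emb0.shape, emb1.shape)  # (2, 16) (2, 8)
```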

**Dynamic Embedding Backend** <br>
SOK provides a dynamic embedding table with hash functionality. The features of the dynamic embedding table are as follows:
- The memory capacity of the embedding table grows dynamically during training, and entries are automatically evicted when the capacity limit is reached.
- The dynamic embedding table can accept pre-hashed indices.

To support these two features, SOK uses [HierarchicalKV](https://github.com/NVIDIA-Merlin/HierarchicalKV) as the backend.

HierarchicalKV (abbreviated as HKV) is part of NVIDIA Merlin and provides hierarchical key-value storage to meet RecSys requirements. For more detailed information about HKV, see the documentation in the [HKV repo](https://github.com/NVIDIA-Merlin/HierarchicalKV). Here, SOK lists some key points of HKV for recommendation systems:
1. HKV supports lookup with pre-hashed indices.
2. HKV supports setting a capacity limit for the embedding table. When the embedding table reaches its capacity, it can evict embedding vectors using strategies such as LRU and LFU.
3. HKV can place embedding vectors in either GPU global memory or host memory, taking full advantage of the system's memory.

With these HKV features, SOK supports more flexible index lookups while also allowing the embedding table to grow to a very large capacity.

To learn how to use dynamic embedding in SOK, see [SOK DynamicVariable](https://nvidia-merlin.github.io/HugeCTR/sparse_operation_kit/master/get_started/get_started.html#sok-dynamicvariable).

For an example of training DLRM with dynamic embedding, see the [SOK Notebooks](https://github.com/NVIDIA-Merlin/HugeCTR/tree/main/sparse_operation_kit/notebooks).
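
A minimal sketch of creating and looking up such a dynamic table is shown below; it follows the `sok.DynamicVariable` usage from the getting-started guide linked above, and the dimension and capacity values are illustrative placeholders rather than recommendations.
```python
import tensorflow as tf
import sparse_operation_kit as sok

sok.init()  # in a multi-GPU job, call horovod's hvd.init() before this

# A hash-based dynamic table backed by HKV ("hybrid" = HBM + host memory).
table = sok.DynamicVariable(
    dimension=16,                   # embedding vector size
    var_type="hybrid",              # use HierarchicalKV as the backend
    init_capacity=1024 * 1024,      # initial number of KV pairs
    max_capacity=64 * 1024 * 1024,  # grows up to this bound, then evicts entries
)

# Pre-hashed, non-contiguous indices are looked up directly.
indices = tf.ragged.constant([[2**40 + 7, 12345], [987654321]], dtype=tf.int64)
embeddings = sok.lookup_sparse([table], [indices], combiners=["sum"])
print(embeddings[0].shape)  # (2, 16)
```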

## Installation ##
There are several ways to install this package. <br>

### Obtaining SOK and HugeCTR via Docker ###
### Obtaining SOK And HugeCTR Via Docker ###
This is the quickest way to get started with SOK.
We provide containers with pre-compiled binaries of the latest HugeCTR and SOK versions (you can also manually install SOK into the `nvcr.io/nvidia/tensorflow` series of images).
To get started quickly with a container on your machine, run the following command:
@@ -37,13 +53,13 @@ You can import the library as shown in the following code block:
```python
import sparse_operation_kit as sok
```

### Installing SOK via pip ###
### Installing SOK Via PIP ###
You can install SOK using the following command:
```bash
pip install sparse_operation_kit --no-build-isolation
```

### Installing SOK from source ###
### Installing SOK From Source ###
You can also build the SOK module from source code. Here are the steps to follow: <br>
+ **Download the source code**
```shell
124 changes: 124 additions & 0 deletions sparse_operation_kit/SOK_DLRM_Benchmark/README.md
@@ -0,0 +1,124 @@
# Benchmark DLRM DCNV2 using TF + SOK + HKV

We need several steps to run the benchmark.

## Environment

1. Pull a Merlin TensorFlow docker image:
```bash
docker pull nvcr.io/nvidia/merlin/merlin-tensorflow:nightly
```

2. Launch the Merlin TensorFlow container with the following command:
```bash
docker run --runtime=nvidia --rm -it -p 8888:8888 -p 8797:8786 --ipc=host --cap-add SYS_NICE nvcr.io/nvidia/merlin/merlin-tensorflow:nightly
```

3. Install SOK+HKV from source code:
```bash
git clone ssh://[email protected]:12051/dl/hugectr.git
cd hugectr
git submodule init && git submodule update
cd sparse_operation_kit
mkdir build && cd build
cmake -DSM={your SM version} ..
make -j && make install
rm -rf /usr/local/lib/python3.10/dist-packages/merlin_sok-1.x-py3.10-linux-x86_64.egg
cp -r ../sparse_operation_kit /usr/local/lib/python3.10/dist-packages/
```
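
After installing, a quick sanity check (an optional, illustrative step) is to confirm that the module can be imported from the copied package:
```bash
python3 -c "import sparse_operation_kit as sok; print(sok.__file__)"
```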
## How to Prepare Dataset
Please generate training data according to [the DLRM DCNV2 documentation](https://github.com/mlcommons/training_results_v3.1/tree/main/NVIDIA/benchmarks/dlrm_dcnv2/implementations/hugectr#prepare-the-input-dataset).

## Benchmark

1. Go to the work directory:
```bash
cd documents/tutorials/DLRM_Benchmark
```

2. Prepare Criteo Terabyte dataset

```bash
# train_data.bin and test_data.bin are the binary datasets generated by HugeCTR
# {splited_dataset} is the target directory in which to save the split dataset
python3 ./preprocess/split_bin.py /path/to/train_data.bin splited_dataset/train --slot_size_array="[39884406,39043,17289,7420,20263,3,7120,1543,63,38532951,2953546,403346,10,2208,11938,155,4,976,14,39979771,25641295,39664984,585935,12972,108,36]"
python3 ./preprocess/split_bin.py /path/to/test_data.bin splited_dataset/test --slot_size_array="[39884406,39043,17289,7420,20263,3,7120,1543,63,38532951,2953546,403346,10,2208,11938,155,4,976,14,39979771,25641295,39664984,585935,12972,108,36]"
```


3. Run the benchmark:

Typically one GPU will be allocated per process, so if a server has 4 GPUs, you will run 4 processes. In `horovodrun`, the number of processes is specified with the `-np` flag.

```bash
# batch size = 65536
horovodrun -np ${num_gpus} ./hvd_wrapper.sh python3 main.py --data_dir=./splited_dataset/ --global_batch=65536 --epochs=100 --lr=24
# batch size = 32768
horovodrun -np ${num_gpus} ./hvd_wrapper.sh python3 main.py --data_dir=./splited_dataset/ --global_batch=32768 --epochs=100 --lr=24
```

## Details About Customized Tests

### 1. Initialize the HKV

There are three key options when creating an HKV instance:

- `init_capacity`: The maximum number of KV pairs that HKV can hold when it is first created.
- `max_capacity`: The maximum number of KV pairs that can be held once HKV has grown to its final size (i.e., through the end of the training process). During training, if the load factor exceeds a threshold, HKV's capacity is doubled, but it never exceeds `max_capacity`.
- `max_hbm_for_vectors`: The maximum amount of HBM that HKV may use to store values (vectors, embeddings). HKV does not occupy all of it at once; instead, it allocates this memory as needed. Make sure the system can actually provide this much HBM, or the program will crash.
### 2. Optimizer
We can also change the optimizer with `--optimizer_name`; currently `sgd`, `adamax`, `adagrad`, `adadelta`, and `ftrl` are supported. An example command is shown below.
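
For example, to run the 65536-batch configuration shown above with Adagrad (an illustrative command that simply appends the flag to the earlier invocation):
```bash
horovodrun -np ${num_gpus} ./hvd_wrapper.sh python3 main.py \
    --data_dir=./splited_dataset/ --global_batch=65536 --epochs=100 --lr=24 \
    --optimizer_name=adagrad
```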
### 3. DynamicVariable Configuration
#### 3.1 Default Behavior
When we choose HKV as the backend of SOK, `DynamicVariable` should be initialized in this way:
```python
self._sok_embedding = sok.DynamicVariable(
var_type="hybrid",
dimension=self._embedding_vec_size, # 128 in Criteo Terabyte Dataset
)
```
By default, `init_capacity` and `max_capacity` of HKV are both set to 64 * 1024 * 1024, and `max_hbm_for_vectors` is 16 GB.
#### 3.2 Customize
We can also customize the configuration of HKV:
```python
self._sok_embedding = sok.DynamicVariable(
var_type="hybrid",
dimension=self._embedding_vec_size, # 128 in Criteo Terabyte Dataset
init_capacity = 1024 * 1024,
max_capacity = 1024 * 1024,
max_hbm_for_vectors=30, # unit:GB
)
```
Be careful when setting `max_hbm_for_vectors`; three factors affect the appropriate value:

- Total HBM size.
- Type of optimizer.
- Batch size.

These factors limit the HBM that is actually available to HKV; if the value is set inappropriately, the program is at risk of running out of memory.
Note that HKV never consumes more resources than it needs: it only uses `max_capacity * dimension * elementSize` bytes to store embeddings when that amount is smaller than the `max_hbm_for_vectors` limit.
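
As a back-of-the-envelope example (a rough sizing sketch, not SOK code), the default configuration from section 3.1 works out as follows:
```python
# Estimate the HBM needed to store all embedding values at full capacity.
max_capacity = 64 * 1024 * 1024   # default number of KV pairs
dimension = 128                   # embedding vector size used for Criteo Terabyte
element_size = 4                  # bytes per float32 element

value_storage_gib = max_capacity * dimension * element_size / 1024**3
print(value_storage_gib)  # 32.0 -> larger than the 16 GB default max_hbm_for_vectors,
                          # so HKV would keep the overflow in host memory
```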

| batch size \ optimizer | SGD | Adamax | Adagrad | Adadelta | FTRL |
| --- | --- | --- | --- | --- | --- |
| 32768 | 60G | 20G | 35G | 20G | 20G |
| 65536 | 60G | 20G | 20G | 20G | 20G |
| 131072 | 60G | 20G | 20G | 20G | 10G |
| 262144 | 60G | 20G | 20G | 20G | 10G |
### Performance on 8 x H100
| batch size | exit criteria | frequency of evaluation | xla | amp | training time (minutes) | evaluating time (minutes) | total time (minutes) | average time of iteration (ms) | throughput (samples/second) |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| 65536 | 1 epoch | at end | yes | yes | no | yes | 8.79 | 0.10 | 4.16M |
| 65536 | 1 epoch | at end | yes | yes | yes | no | 6.72 | 0.09 | 3.45M |
81 changes: 55 additions & 26 deletions ...k/benchmark/dlrm/lookup_sparse/dataset.py → ...eration_kit/SOK_DLRM_Benchmark/dataset.py
100644 → 100755
@@ -1,25 +1,10 @@
"""
Copyright (c) 2022, NVIDIA CORPORATION.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import queue
import concurrent

import numpy as np
import tensorflow as tf
from typing import List, Optional


class BinaryDataset:
@@ -36,13 +21,17 @@ def __init__(
label_raw_type=np.int32,
dense_raw_type=np.int32,
category_raw_type=np.int32,
hotness_per_table: List[int] = None,
log=True,
):
"""
* batch_size : The batch size of local rank, which means the total batch size of all ranks should be (batch_size * global_size).
* prefetch : If prefetch > 1, it can only be read sequentially, otherwise it will return incorrect samples.
"""
self._check_file(label_bin, dense_bin, category_bin)
self._hotness_per_table = hotness_per_table
self._hotness_per_sample = np.sum(hotness_per_table)
self._num_category = len(hotness_per_table)
self._check_file(label_bin, dense_bin, category_bin, 4, self._hotness_per_sample)

self._batch_size = batch_size
self._drop_last = drop_last
@@ -104,14 +93,29 @@ def __init__(

self._log = log

def _check_file(self, label_bin, dense_bin, category_bin):
self._row_lengths = []
for i in range(len(self._hotness_per_table)):
self._row_lengths.append(
tf.repeat(self._hotness_per_table[i], repeats=self._batch_size)
)

def _check_file(
self, label_bin, dense_bin, category_bin, label_type_byte, sparse_hotness_per_sample
):
# num_samples represents the actual number of samples in the dataset
num_samples = os.path.getsize(label_bin) // 4

num_samples = os.path.getsize(label_bin) // label_type_byte
if num_samples <= 0:
raise RuntimeError("There must be at least one sample in %s" % label_bin)

if num_samples <= 0:
raise RuntimeError("There must be at least one sample in %s" % label_bin)
# check file size
for file, bytes_per_sample in [[label_bin, 4], [dense_bin, 52], [category_bin, 104]]:
for file, bytes_per_sample in [
[label_bin, 4],
[dense_bin, 52],
[category_bin, 4 * sparse_hotness_per_sample],
]:
file_size = os.path.getsize(file)
if file_size % bytes_per_sample != 0:
raise RuntimeError(
@@ -160,20 +164,30 @@ def _get(self, idx):
self._batch_size * self._global_rank
)
batch = self._batch_size
row_lengths = self._row_lengths
if batch != self._batch_size:
row_lengths = []
for i in range(len(self._hotness_per_table)):
row_lengths.append(tf.repeat(self._hotness_per_table[i], repeats=batch))

# read the data from binary file
label_raw_data = os.pread(self._label_file, 4 * batch, 4 * sample_offset)
label = np.frombuffer(label_raw_data, dtype=self._label_raw_type).reshape([batch, 1])

dense_raw_data = os.pread(self._dense_file, 52 * batch, 52 * sample_offset)
dense = np.frombuffer(dense_raw_data, dtype=self._dense_raw_type).reshape([batch, 13])

category_raw_data = os.pread(self._category_file, 104 * batch, 104 * sample_offset)
category_raw_data = os.pread(
self._category_file,
self._hotness_per_sample * 4 * batch,
self._hotness_per_sample * 4 * sample_offset,
)
category = np.frombuffer(category_raw_data, dtype=self._category_raw_type).reshape(
[batch, 26]
[batch, self._hotness_per_sample]
)
indices = np.cumsum(self._hotness_per_table)[:-1]

sub_arrays = np.split(category, indices, axis=1)

# convert numpy data to tensorflow data
if (
self._label_raw_type == self._dense_raw_type
and self._label_raw_type == self._category_raw_type
@@ -182,14 +196,29 @@
data = tf.convert_to_tensor(data)
label = tf.cast(data[:, 0:1], dtype=tf.float32)
dense = tf.cast(data[:, 1:14], dtype=tf.float32)
category = tf.cast(data[:, 14:40], dtype=tf.int64)
category_np = data[:, 14:]
flat_values = tf.reshape(category_np, [-1])
category_ragged_tensors = []
for i in range(len(sub_arrays)):
flat_values = tf.reshape(sub_arrays[i], [-1])
category_ragged_tensors.append(
tf.RaggedTensor.from_row_lengths(flat_values, row_lengths[i])
)

else:
label = tf.convert_to_tensor(label, dtype=tf.float32)
dense = tf.convert_to_tensor(dense, dtype=tf.float32)
category = tf.convert_to_tensor(category, dtype=tf.int64)

category_ragged_tensors = []
for i in range(len(sub_arrays)):
flat_values = tf.reshape(sub_arrays[i], [-1])
category_ragged_tensors.append(
tf.RaggedTensor.from_row_lengths(flat_values, row_lengths[i])
)

# preprocess
if self._log:
dense = tf.math.log(dense + 3.0)

return (dense, category), label
return (dense, category_ragged_tensors), label