Skip to content

Commit

Permalink
Introduce load_factor - use dtype.int by default in HashMap benchmark (#396)
Browse files Browse the repository at this point in the history

It is not recommended to use `int64_t` in `CUDAHashMap` since it uses
`atomicCAS` internally, which is not optimized for 8-byte precision at
all.
  • Loading branch information
rusty1s authored Feb 4, 2025
1 parent 352d9d3 commit c6456b7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
21 changes: 13 additions & 8 deletions benchmark/classes/hash_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,29 @@
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--dtype', type=str, default='int',
choices=['short', 'int', 'long'])
parser.add_argument('--num_keys', type=int, default=10_000_000)
parser.add_argument('--num_queries', type=int, default=1_000_000)
args = parser.parse_args()

dtype = getattr(torch, args.dtype)

args.num_queries = min(args.num_queries, args.num_keys)

num_warmups, num_steps = 50, 100
if args.device == 'cpu':
num_warmups, num_steps = num_warmups // 10, num_steps // 10

max_value = torch.iinfo(torch.long).max
max_value = torch.iinfo(dtype).max

key1 = torch.randint(0, max_value, (args.num_keys, ), dtype=torch.long,
key1 = torch.randint(0, max_value, (args.num_keys, ), dtype=dtype,
device=args.device)
query1 = key1[torch.randperm(key1.size(0), device=args.device)]
query1 = query1[:args.num_queries]

key2 = torch.randperm(args.num_keys, device=args.device)
query2 = torch.randperm(args.num_queries, device=args.device)
key2 = torch.randperm(args.num_keys, dtype=dtype, device=args.device)
query2 = torch.randperm(args.num_queries, dtype=dtype, device=args.device)
query2 = query2[:args.num_queries]

if key1.is_cpu:
Expand Down Expand Up @@ -60,15 +64,16 @@
for i in range(num_warmups + num_steps):
torch.cuda.synchronize()
t_start = time.perf_counter()
hash_map = torch.full((args.num_keys, ), fill_value=-1,
dtype=torch.long, device=args.device)
hash_map[key2] = torch.arange(args.num_keys, device=args.device)
hash_map = torch.full((args.num_keys, ), fill_value=-1, dtype=dtype,
device=args.device)
hash_map[key2.long()] = torch.arange(args.num_keys, dtype=dtype,
device=args.device)
torch.cuda.synchronize()
if i >= num_warmups:
t_init += time.perf_counter() - t_start

t_start = time.perf_counter()
out2 = hash_map[query2]
out2 = hash_map[query2.long()]
torch.cuda.synchronize()
if i >= num_warmups:
t_get += time.perf_counter() - t_start
Expand Down
10 changes: 5 additions & 5 deletions pyg_lib/csrc/classes/cuda/hash_map.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ struct CUDAHashMapImpl : HashMapImpl {
public:
using ValueType = int64_t;

CUDAHashMapImpl(const at::Tensor& key) {
CUDAHashMapImpl(const at::Tensor& key, double load_factor) {
KeyType constexpr empty_key_sentinel = std::numeric_limits<KeyType>::min();
ValueType constexpr empty_value_sentinel = -1;

size_t capacity = std::ceil(key.numel() / load_factor);
map_ = std::make_unique<cuco::static_map<KeyType, ValueType>>(
2 * key.numel(), // load_factor = 0.5
cuco::empty_key{empty_key_sentinel},
capacity, cuco::empty_key{empty_key_sentinel},
cuco::empty_value{empty_value_sentinel});

const auto key_data = key.data_ptr<KeyType>();
Expand Down Expand Up @@ -89,15 +89,15 @@ struct CUDAHashMapImpl : HashMapImpl {

struct CUDAHashMap : torch::CustomClassHolder {
public:
CUDAHashMap(const at::Tensor& key) {
CUDAHashMap(const at::Tensor& key, double load_factor = 0.5) {
at::TensorArg key_arg{key, "key", 0};
at::CheckedFrom c{"CUDAHashMap.init"};
at::checkDeviceType(c, key, at::DeviceType::CUDA);
at::checkDim(c, key_arg, 1);
at::checkContiguous(c, key_arg);

DISPATCH_KEY(key.scalar_type(), "cuda_hash_map_init", [&] {
map_ = std::make_unique<CUDAHashMapImpl<scalar_t>>(key);
map_ = std::make_unique<CUDAHashMapImpl<scalar_t>>(key, load_factor);
});
}

Expand Down

0 comments on commit c6456b7

Please sign in to comment.