Skip to content

Commit

Permalink
Merge pull request rapidsai#4263 from rapidsai/branch-24.04
Browse files Browse the repository at this point in the history
Forward-merge branch-24.04 to branch-24.06
  • Loading branch information
GPUtester authored Mar 20, 2024
2 parents b2a3890 + 4becbe8 commit 9d7b5be
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,10 @@ def forward(
if edge_envelope is not None:
out = out * edge_envelope.view(-1, 1)

out = scatter_reduce(out, dst, dim=0, dim_size=num_dst_nodes, reduce=reduce)
dtype = out.dtype
out = scatter_reduce(
out.float(), dst, dim=0, dim_size=num_dst_nodes, reduce=reduce
).to(dtype)

if self.batch_norm:
out = self.batch_norm(out)
Expand Down

0 comments on commit 9d7b5be

Please sign in to comment.