Revert "nccl op changes"
This reverts commit ee4a9c8.
apbose committed Jan 22, 2025
1 parent ee4a9c8 commit 17ed16f
Showing 4 changed files with 12 additions and 24 deletions.

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py (5 additions, 3 deletions)

@@ -3,7 +3,6 @@
 import logging
 from typing import Dict, Sequence, Tuple, Union
 
-import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -17,6 +16,8 @@
     tensorrt_fused_nccl_reduce_scatter_op,
 )
 
+import tensorrt as trt
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 if load_tensorrt_llm():
@@ -29,7 +30,7 @@ def fused_nccl_gather(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.nccl_ops.nccl_gather(
+        return impl.distributed.nccl_gather(
             ctx,
             target,
             SourceIR.ATEN,
@@ -45,14 +46,15 @@ def fused_nccl_reduce_scatter(
         kwargs: Dict[str, Argument],
         name: str,
    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.nccl_ops.nccl_reduce_scatter(
+        return impl.distributed.nccl_reduce_scatter(
             ctx,
             target,
             SourceIR.ATEN,
             name,
             [args[0]],
         )
 
+    breakpoint()
 else:
     _LOGGER.debug(
         "Did not load torch.distributed converters since TensorRT-LLM is not available"

@@ -106,12 +106,7 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:
 
     if op_target in shape_inference_funcs:
         new_shape = shape_inference_funcs[op_target](node)
-        new_node_dtype = None
-        if node.meta["val"].dtype == torch.complex64:
-            new_node_dtype = torch.float32
-        else:
-            new_node_dtype = torch.float64
-        real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
+        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
         node.meta["val"] = fake_mode.from_tensor(real_tensor)
     else:
         print("No shape for the inference function", {op_name})

py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py (6 additions, 6 deletions)

@@ -49,6 +49,7 @@ def fuse_distributed_ops(
             == torch.ops._c10d_functional.wait_tensor.default
         ):
             wait_tensor_node = list(node.users)[0]
+            fused_op = None
             if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                 with gm.graph.inserting_after(wait_tensor_node):
                     fused_node = gm.graph.create_node(
@@ -57,12 +58,11 @@
                         args=(node.args[0], node.args[1], node.args[2]),
                     )
             else:
-                with gm.graph.inserting_after(wait_tensor_node):
-                    fused_node = gm.graph.create_node(
-                        op="call_function",
-                        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-                        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-                    )
+                fused_node = gm.graph.create_node(
+                    op="call_function",
+                    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+                    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+                )
 
             wait_tensor_node.replace_all_uses_with(fused_node)
             fused_node.meta.update(node.meta)
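
For orientation, a minimal, self-contained sketch of what this fusion pattern does: a collective op whose only user is wait_tensor is collapsed into a single fused node that the converters in custom_ops_converters.py can then match. The stand-in target fused_nccl_all_gather and the trailing erase/recompile steps are illustrative assumptions, not code from this diff:

    import torch
    import torch.fx


    def fused_nccl_all_gather(tensor, group_size, group_name):
        # Hypothetical stand-in for the fused NCCL all-gather custom op used by
        # the real pass; any callable can serve as a call_function target.
        return tensor


    def fuse_all_gather(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in list(gm.graph.nodes):
            if (
                node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
                and len(node.users) == 1
                and list(node.users)[0].target
                == torch.ops._c10d_functional.wait_tensor.default
            ):
                wait_tensor_node = list(node.users)[0]
                with gm.graph.inserting_after(wait_tensor_node):
                    fused_node = gm.graph.create_node(
                        op="call_function",
                        target=fused_nccl_all_gather,
                        args=(node.args[0], node.args[1], node.args[2]),
                    )
                wait_tensor_node.replace_all_uses_with(fused_node)
                fused_node.meta.update(node.meta)
                # Drop the now-unused wait_tensor and collective nodes (assumed cleanup).
                gm.graph.erase_node(wait_tensor_node)
                gm.graph.erase_node(node)
        gm.recompile()
        return gm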

@@ -364,15 +364,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-
-        for i, contiguous_input in enumerate(contiguous_inputs):
-            if contiguous_input.dtype == torch.complex64:
-                contiguous_input_real = contiguous_input.real
-                contiguous_input_imag = contiguous_input.imag
-                contiguous_inputs[i] = torch.stack(
-                    (contiguous_input_real, contiguous_input_imag), dim=-1
-                )
-
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
