hotfix: revert torch.library register (#709)
We observe a performance degradation for small operations in flashinfer
v0.2 caused by the overhead of `torch.library.custom_op`, introduced in
#554.
This PR disables the torch custom operator registrations for now; we can add
them back later with a lightweight registration, e.g.:
https://github.com/vllm-project/vllm/blob/36e76700453924c8d421db99af70a88a1df835cd/vllm/utils.py#L1660-L1674
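
For reference, a lightweight registration in the spirit of the linked vLLM helper might look roughly like the sketch below. The helper name `direct_register_custom_op`, the `flashinfer` namespace, and the CUDA-only dispatch key are illustrative assumptions, not part of this PR, and `torch.library.infer_schema` requires a recent PyTorch release:

```python
# Sketch only: register an op directly on a torch.library.Library object,
# avoiding the per-op wrapper that torch.library.custom_op builds.
from typing import Callable, Optional, Sequence

import torch
from torch.library import Library

_lib = Library("flashinfer", "FRAGMENT")  # hypothetical op namespace


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Sequence[str] = (),
    fake_impl: Optional[Callable] = None,
) -> None:
    # Infer the schema from the Python signature, then register the real and
    # (optionally) fake implementations directly on the Library object.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    _lib.define(op_name + schema)
    _lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        torch.library.register_fake(f"flashinfer::{op_name}", fake_impl)
```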

cc @zhyncs @abcdabcd987 @youkaichao
yzh119 authored Dec 30, 2024
1 parent 4ba91c0 commit ccd3be9
Showing 1 changed file with 13 additions and 8 deletions.
flashinfer/utils.py (13 additions, 8 deletions)

@@ -236,19 +236,24 @@ def register_custom_op(
     device_types: Optional[Union[str, Sequence[str]]] = None,
     schema: Optional[str] = None,
 ) -> Callable:
-    return torch.library.custom_op(
-        name,
-        fn,
-        mutates_args=mutates_args,
-        device_types=device_types,
-        schema=schema,
-    )
+    # NOTE(Zihao): torch.library.custom_op has significant overhead as mentioned in the following link
+    # https://github.com/vllm-project/vllm/blob/36e76700453924c8d421db99af70a88a1df835cd/vllm/utils.py#L1660-L1674
+
+    # return torch.library.custom_op(
+    #     name,
+    #     fn,
+    #     mutates_args=mutates_args,
+    #     device_types=device_types,
+    #     schema=schema,
+    # )
+    return lambda x: x

 def register_fake_op(
     name: str,
     fn: Optional[Callable] = None,
 ) -> Callable:
-    return torch.library.register_fake(name, fn)
+    # return torch.library.register_fake(name, fn)
+    return lambda x: x


 def get_cuda_stream(device: torch.device) -> int:
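
With this change the decorators become no-ops: they return the wrapped function unchanged, so existing call sites keep working but no torch.library operator is registered. A minimal usage sketch, assuming the signatures shown in the diff (the op name and function below are made up for illustration):

```python
from flashinfer.utils import register_custom_op, register_fake_op


# The decorator now returns the function as-is, so this is a plain Python
# function with no dispatcher involvement. "flashinfer::demo_add" is illustrative.
@register_custom_op("flashinfer::demo_add", mutates_args=())
def demo_add(a, b):
    return a + b


@register_fake_op("flashinfer::demo_add")
def _demo_add_fake(a, b):
    return a + b


assert demo_add(1, 2) == 3  # ordinary Python call; no torch.library op registered
```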
