Skip to content

Commit

Permalink
use /( in regex match
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed Apr 27, 2024
1 parent 857b8db commit 8e3de02
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]):
else:
new_size = (input_size[0] + 1, input_size[1])
nf4_tensor = to_nf4(torch.randn(input_size))
with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\\(NF4Tensor\\) with new size"):
with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\(NF4Tensor\) with new size"):
nf4_tensor_zeros = nf4_tensor.new_zeros(new_size)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
Expand All @@ -305,20 +305,20 @@ def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]):

def test_tensor_slice_1d_invalid(self):
nf4_tensor = to_nf4(torch.randn(512 * 512))
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with step"):
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with step"):
nf4_tensor[..., ::2]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
nf4_tensor[1:]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end "):
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end "):
nf4_tensor[:2]

def test_tensor_slice_2d_invalid(self):
nf4_tensor = to_nf4(torch.randn((512, 512)))
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with dim"):
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with dim"):
nf4_tensor[:, :511]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
nf4_tensor[1:]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"):
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end"):
nf4_tensor[:2]

@parametrize("input_size", [(512 * 512,), (512, 512)])
Expand All @@ -334,9 +334,9 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
@parametrize("input_size", [(512 * 512,), (512, 512)])
def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
with self.assertRaisesRegex(NotImplementedError, "aten.view\(NF4Tensor\) with size"):
nf4_tensor.view(input_size)
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
with self.assertRaisesRegex(NotImplementedError, "aten.view\(NF4Tensor\) with size"):
nf4_tensor.view(input_size)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
Expand All @@ -361,13 +361,13 @@ def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]):
size = (input_size[0] - 1, )
else:
size = (input_size[0] - 1, input_size[1])
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) different numel"):
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\(NF4Tensor\) different numel"):
torch.as_strided(nf4_tensor, size, nf4_tensor.stride(), nf4_tensor.storage_offset())
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support original storage offset"):
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\(NF4Tensor\) only support original storage offset"):
torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), 1)

if len(input_size) == 2:
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support continuous stride"):
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\(NF4Tensor\) only support continuous stride"):
stride = (nf4_tensor.stride()[1], nf4_tensor.stride()[0])
torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset())

Expand Down

0 comments on commit 8e3de02

Please sign in to comment.