diff --git a/test/test_base.py b/test/test_base.py
index b2ee071..2c7c3f4 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -232,58 +232,18 @@ def _test_linear_impl(
     @pytest.mark.parametrize(
         "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
     )
+    @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
+    @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear_nobias(
+    def test_linear(
         self,
         x_shape,
         emulate: bool,
         scaling_type_x: TensorScalingType,
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
-    ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        x = torch.randn(*x_shape, device="cuda")
-        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(
-            x,
-            m_ref,
-            emulate,
-            scaling_type_x,
-            scaling_type_w,
-            scaling_type_dL_dY,
-        )
-
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
-    )
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear_bias(
-        self,
-        x_shape,
-        scaling_type_x: TensorScalingType,
-        scaling_type_w: TensorScalingType,
-        scaling_type_dL_dY: TensorScalingType,
-        emulate: bool,
         linear_dtype: torch.dtype,
+        linear_bias: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -295,7 +255,7 @@ def test_linear_bias(
                 )
                 pytest.skip()
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(
             x,
             m_ref,
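
For reference, here is a minimal standalone sketch (not part of the diff; the test name, shapes, and assertions are illustrative) of the consolidation pattern the change relies on: stacked `@pytest.mark.parametrize` decorators expand into the cartesian product of their argument lists, so adding a `linear_bias` parameter covers the bias and no-bias variants from a single test function instead of two near-duplicate ones.

```python
# Minimal sketch of stacked pytest parametrization; runs on CPU, no CUDA needed.
import pytest
import torch
import torch.nn as nn


@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
def test_linear_sketch(linear_dtype: torch.dtype, linear_bias: bool):
    # 2 dtypes x 2 bias settings -> pytest collects 4 cases from this one function.
    m_ref = nn.Linear(16, 32, bias=linear_bias, dtype=linear_dtype)
    x = torch.randn(4, 16, dtype=linear_dtype)
    y = m_ref(x)
    assert y.shape == (4, 32)
    # The bias attribute exists exactly when the parametrized flag requests it.
    assert (m_ref.bias is not None) == linear_bias
```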