diff --git a/torchx/specs/named_resources_aws.py b/torchx/specs/named_resources_aws.py index 49456a0c..2fa7c9b0 100644 --- a/torchx/specs/named_resources_aws.py +++ b/torchx/specs/named_resources_aws.py @@ -354,6 +354,46 @@ def aws_trn1_32xlarge() -> Resource: ) +def aws_inf2_xlarge() -> Resource: + return Resource( + cpu=4, + gpu=0, + memMB=16 * GiB, + capabilities={K8S_ITYPE: "inf2.xlarge"}, + devices={NEURON_DEVICE: 1}, + ) + + +def aws_inf2_8xlarge() -> Resource: + return Resource( + cpu=32, + gpu=0, + memMB=128 * GiB, + capabilities={K8S_ITYPE: "inf2.8xlarge"}, + devices={NEURON_DEVICE: 1}, + ) + + +def aws_inf2_24xlarge() -> Resource: + return Resource( + cpu=96, + gpu=0, + memMB=384 * GiB, + capabilities={K8S_ITYPE: "inf2.24xlarge"}, + devices={NEURON_DEVICE: 6}, + ) + + +def aws_inf2_48xlarge() -> Resource: + return Resource( + cpu=192, + gpu=0, + memMB=768 * GiB, + capabilities={K8S_ITYPE: "inf2.48xlarge"}, + devices={NEURON_DEVICE: 12}, + ) + + NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = { "aws_t3.medium": aws_t3_medium, "aws_m5.2xlarge": aws_m5_2xlarge, @@ -390,4 +430,8 @@ def aws_trn1_32xlarge() -> Resource: "aws_g6e.48xlarge": aws_g6e_48xlarge, "aws_trn1.2xlarge": aws_trn1_2xlarge, "aws_trn1.32xlarge": aws_trn1_32xlarge, + "aws_inf2.xlarge": aws_inf2_xlarge, + "aws_inf2.8xlarge": aws_inf2_8xlarge, + "aws_inf2.24xlarge": aws_inf2_24xlarge, + "aws_inf2.48xlarge": aws_inf2_48xlarge } diff --git a/torchx/specs/test/named_resources_aws_test.py b/torchx/specs/test/named_resources_aws_test.py index fcd4526d..556bd35d 100644 --- a/torchx/specs/test/named_resources_aws_test.py +++ b/torchx/specs/test/named_resources_aws_test.py @@ -43,6 +43,10 @@ aws_t3_medium, aws_trn1_2xlarge, aws_trn1_32xlarge, + aws_inf2_xlarge, + aws_inf2_8xlarge, + aws_inf2_24xlarge, + aws_inf2_48xlarge, EFA_DEVICE, GiB, K8S_ITYPE, @@ -232,6 +236,31 @@ def test_aws_trn1(self) -> None: self.assertEqual(trn1_32.memMB, trn1_2.memMB * 16) self.assertEqual({EFA_DEVICE: 8, NEURON_DEVICE: 16}, trn1_32.devices) + def test_aws_inf2(self) -> None: + inf2_xlarge = aws_inf2_xlarge() + self.assertEqual(4, inf2_xlarge.cpu) + self.assertEqual(0, inf2_xlarge.gpu) + self.assertEqual(16 * GiB, inf2_xlarge.memMB) + self.assertEqual({NEURON_DEVICE: 1}, inf2_xlarge.devices) + + inf2_8xlarge = aws_inf2_8xlarge() + self.assertEqual(32, inf2_8xlarge.cpu) + self.assertEqual(0, inf2_8xlarge.gpu) + self.assertEqual(128 * GiB, inf2_8xlarge.memMB) + self.assertEqual({NEURON_DEVICE: 1}, inf2_8xlarge.devices) + + inf2_24xlarge = aws_inf2_24xlarge() + self.assertEqual(96, inf2_24xlarge.cpu) + self.assertEqual(0, inf2_24xlarge.gpu) + self.assertEqual(384 * GiB, inf2_24xlarge.memMB) + self.assertEqual({NEURON_DEVICE: 6}, inf2_24xlarge.devices) + + inf2_48xlarge = aws_inf2_48xlarge() + self.assertEqual(192, inf2_48xlarge.cpu) + self.assertEqual(0, inf2_48xlarge.gpu) + self.assertEqual(768 * GiB, inf2_48xlarge.memMB) + self.assertEqual({NEURON_DEVICE: 12}, inf2_48xlarge.devices) + def test_aws_m5_2xlarge(self) -> None: resource = aws_m5_2xlarge() self.assertEqual(8, resource.cpu)