diff --git a/task/aws/resources/data_source_image.go b/task/aws/resources/data_source_image.go index ee46a7f6..aa821b58 100644 --- a/task/aws/resources/data_source_image.go +++ b/task/aws/resources/data_source_image.go @@ -34,6 +34,7 @@ func (i *Image) Read(ctx context.Context) error { image := i.Identifier images := map[string]string{ "ubuntu": "ubuntu@099720109477:x86_64:*ubuntu/images/hvm-ssd/ubuntu-focal-20.04*", + "nvidia": "ubuntu@679593333241:x86_64:NVIDIA Deep Learning AMI v21.02.2-*", } if val, ok := images[image]; ok { image = val diff --git a/task/az/resources/resource_virtual_machine_scale_set.go b/task/az/resources/resource_virtual_machine_scale_set.go index 77e33103..af740cbb 100644 --- a/task/az/resources/resource_virtual_machine_scale_set.go +++ b/task/az/resources/resource_virtual_machine_scale_set.go @@ -82,12 +82,13 @@ func (v *VirtualMachineScaleSet) Create(ctx context.Context) error { image := v.Attributes.Environment.Image images := map[string]string{ "ubuntu": "ubuntu@Canonical:0001-com-ubuntu-server-focal:20_04-lts:latest", + "nvidia": "ubuntu@nvidia:ngc_base_image_version_b:gen2_21-11-0:latest#plan", } if val, ok := images[image]; ok { image = val } - imageParts := regexp.MustCompile(`^([^@]+)@([^:]+):([^:]+):([^:]+):([^:]+)$`).FindStringSubmatch(image) + imageParts := regexp.MustCompile(`^([^@]+)@([^:]+):([^:]+):([^:]+):([^:]+)(:?(#plan)?)$`).FindStringSubmatch(image) if imageParts == nil { return errors.New("invalid machine image format: use publisher:offer:sku:version") } @@ -97,6 +98,7 @@ func (v *VirtualMachineScaleSet) Create(ctx context.Context) error { offer := imageParts[3] sku := imageParts[4] version := imageParts[5] + plan := imageParts[6] size := v.Attributes.Size.Machine sizes := map[string]string{ @@ -185,6 +187,14 @@ func (v *VirtualMachineScaleSet) Create(ctx context.Context) error { }, } + if plan == "#plan" { + settings.Plan = &compute.Plan{ + Publisher: to.StringPtr(publisher), + Product: to.StringPtr(offer), + Name: to.StringPtr(sku), + } + } + spot := v.Attributes.Spot if spot >= 0 { if spot == 0 { diff --git a/task/gcp/resources/data_source_image.go b/task/gcp/resources/data_source_image.go index 829de386..6a0fdcb9 100644 --- a/task/gcp/resources/data_source_image.go +++ b/task/gcp/resources/data_source_image.go @@ -32,6 +32,7 @@ func (i *Image) Read(ctx context.Context) error { image := i.Identifier images := map[string]string{ "ubuntu": "ubuntu@ubuntu-os-cloud/ubuntu-2004-lts", + "nvidia": "ubuntu@nvidia-ngc-public/nvidia-gpu-cloud-image-20211105", } if val, ok := images[image]; ok { image = val @@ -39,18 +40,27 @@ func (i *Image) Read(ctx context.Context) error { match := regexp.MustCompile(`^([^@]+)@([^/]+)/([^/]+)$`).FindStringSubmatch(image) if match == nil { - return common.NotFoundError + return errors.New("wrong image name") } i.Attributes.SSHUser = match[1] project := match[2] - family := match[3] + imageOrFamily := match[3] - resource, err := i.Client.Services.Compute.Images.GetFromFamily(project, family).Do() + resource, err := i.Client.Services.Compute.Images.Get(project, imageOrFamily).Do() if err != nil { var e *googleapi.Error if errors.As(err, &e) && e.Code == 404 { - return common.NotFoundError + resource, err := i.Client.Services.Compute.Images.GetFromFamily(project, imageOrFamily).Do() + if err != nil { + var e *googleapi.Error + if errors.As(err, &e) && e.Code == 404 { + return common.NotFoundError + } + return err + } + i.Resource = resource + return nil } return err } diff --git a/task/k8s/resources/resource_job.go b/task/k8s/resources/resource_job.go index 60b3dcd3..bcfbd329 100644 --- a/task/k8s/resources/resource_job.go +++ b/task/k8s/resources/resource_job.go @@ -68,11 +68,19 @@ func (j *Job) Create(ctx context.Context) error { "l+v100": "32-256000+nvidia-tesla-v100*4", "xl+v100": "64-512000+nvidia-tesla-v100*8", } - if val, ok := sizes[size]; ok { size = val } + image := j.Attributes.Task.Environment.Image + images := map[string]string{ + "ubuntu": "ubuntu", + "nvidia": "nvidia/cuda", + } + if val, ok := images[image]; ok { + image = val + } + match := regexp.MustCompile(`^(\d+)-(\d+)(?:\+([^*]+)\*([1-9]\d*))?$`).FindStringSubmatch(size) if match == nil { return common.NotFoundError @@ -206,7 +214,7 @@ func (j *Job) Create(ctx context.Context) error { Containers: []kubernetes_core.Container{ { Name: j.Identifier, - Image: j.Attributes.Task.Environment.Image, + Image: image, Resources: kubernetes_core.ResourceRequirements{ Limits: jobLimits, Requests: kubernetes_core.ResourceList{