Skip to content

Commit

Permalink
[TPU Provisioner] Create admission controller (GoogleCloudPlatform#687)
Browse files Browse the repository at this point in the history
* initial commit of admission controller

* use jsonpatch library
  • Loading branch information
danielvegamyhre authored May 29, 2024
1 parent 62f8d3c commit bca9b96
Show file tree
Hide file tree
Showing 12 changed files with 442 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tpu-provisioner/admission_controller/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# don't add certificates
certificates/*.crt
certificates/*.key

__pycache__/
.pytest_cache/
6 changes: 6 additions & 0 deletions tpu-provisioner/admission_controller/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
FROM python:3.10-slim-buster
WORKDIR /webhook
COPY requirements.txt /webhook
COPY admission_controller.py /webhook
RUN pip install --no-cache-dir --upgrade -r /webhook/requirements.txt
CMD ["uvicorn", "admission_controller:app", "--host", "0.0.0.0", "--port", "5000","--ssl-keyfile=/certs/webhook.key", "--ssl-certfile=/certs/webhook.crt"]
46 changes: 46 additions & 0 deletions tpu-provisioner/admission_controller/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# TPU Provisioner Admission Controller

This is a custom k8s admission controller that can be paired with the TPU provisioner
to dynamically inject node selectors into a Job's pod template based on environment
variables. The TPU provisioner will then provision slices based on the values of
these node selectors.

**NOTE**: This is not a generalized solution that works out of the box for any user - the values
injected by the admission controller are just examples that the user is responsible
for changing to fit their use case.

## Project Structure

```
|- admission_controller.py (mutating webhook)
|- certificates (add TLS certificates here)
|- manifests (deployment manifest for admission controller)
|- test (unit tests)
| - tests
| |-- admission_controller_test.py (unit tests)
| |-- manual_e2e/ (JobSet manifests for manual e2e tests)
| | ...
```

### Prepare container image

1. Build container image: `docker build -t admission-controller -f Dockerfile .`
2. Tag container image: `docker tag admission-controller gcr.io/${PROJECT}/admission-controller:v0.1.0`
3. Push container image: `docker push gcr.io/${PROJECT}/admission-controller:v0.1.0`

Update the Deployment in `manifests/manifest.yaml` with this container image.

### Run Unit tests

This project uses [pytest](https://docs.pytest.org) for unit testing.

To run unit tests, run the command `pytest` from the `admission_controller/` directory.

### Run E2E tests

E2E testing is currently done manually via the following steps:

1. [Install JobSet](https://jobset.sigs.k8s.io/docs/installation/)
2. **Deploy admission controller**: Run `kubectl apply -f manifests/` from the `admission_controller/` directory.
3. **Create a test JobSet**: Run `kubectl apply -f test/test-jobset.yaml`
4. **Check Jobs were mutated correctly**: Run `kubectl describe jobs` and view the nodeSelectors in the pod template.
Empty file.
103 changes: 103 additions & 0 deletions tpu-provisioner/admission_controller/admission_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python3
import os
import json
import base64
import logging
import hashlib
from fastapi import FastAPI, Body
from jsonpatch import JsonPatch
from copy import deepcopy

# ASGI application object; the Dockerfile serves it as `admission_controller:app`.
app = FastAPI()

# Module-level logger, set to INFO so the per-request patch messages are emitted.
webhook_logger = logging.getLogger(__name__)
webhook_logger.setLevel(logging.INFO)
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s")

# environment variables
# These constants hold the NAMES of the variables; make_patches() looks the
# values up in os.environ each time it builds a patch.
LOCATION_HINT = "RESERVATION_LOCATION_HINT"
ALWAYS_HINT_TIME = "ALWAYS_HINT_TIME"  # NOTE(review): not read anywhere in the visible code
FORCE_ON_DEMAND = "FORCE_ON_DEMAND"

# labels
# Keys written to (or removed from) the pod template's nodeSelector.
job_key_label = "job-key"
reservation_name_label = "cloud.google.com/reservation-name"
gke_spot_label = "cloud.google.com/gke-spot"
gke_location_hint_label = "cloud.google.com/gke-location-hint"

# API endpoint
@app.post("/mutate")
def mutate_request(request: dict = Body(...)):
    """Handle an AdmissionReview POSTed to the mutating webhook.

    Args:
        request: The decoded request body. Its "request" entry must carry
            the review "uid" and the "object" under admission.

    Returns:
        The AdmissionReview response dict produced by admission_review().
    """
    admission_request = request["request"]
    uid: str = admission_request["uid"]
    object_in: dict = admission_request["object"]

    # Log which object is being patched before doing any work.
    webhook_logger.info(f'Patching {object_in["kind"]} {object_in["metadata"]["namespace"]}/{object_in["metadata"]["name"]}')

    response: dict = admission_review(uid, object_in)

    # Log the complete response that goes back to the caller.
    webhook_logger.info(f'Response: {json.dumps(response)}')
    return response


def admission_review(uid: str, object_in: dict) -> dict:
    """Build the AdmissionReview response carrying a JSONPatch for `object_in`.

    Args:
        uid: The uid of the incoming admission request, echoed in the response.
        object_in: The k8s object being admitted.

    Returns:
        An AdmissionReview dict that always allows the request and attaches
        the base64-encoded patch returned by patch().
    """
    status: dict = {"message": f"Patched {object_in['kind']}: {object_in['metadata']['namespace']}/{object_in['metadata']['name']}"}

    review_response: dict = {
        "uid": uid,
        "allowed": True,
        "patchType": "JSONPatch",
        "status": status,
        "patch": patch(object_in),
    }

    return {
        "apiVersion": "admission.k8s.io/v1",
        "kind": "AdmissionReview",
        "response": review_response,
    }


def patch(object_in: dict) -> str:
    """Return the base64-encoded JSON Patch for the given k8s object.

    Args:
        object_in: The k8s object being admitted.

    Returns:
        The patch from make_patches(), serialized to a JSON string and
        base64-encoded, as text.
    """
    # make_patches() returns a JsonPatch, not a list of dicts.
    patches: JsonPatch = make_patches(object_in)
    # Serialize explicitly rather than via str(): str() on a plain list would
    # silently produce a Python repr (single quotes), which is not valid JSON.
    serialized: str = patches.to_string()
    return base64.b64encode(serialized.encode()).decode()


def make_patches(object_in: dict) -> JsonPatch:
    """Compute the JsonPatch that applies the env-driven nodeSelector mutations.

    The input is never modified: a deep copy is mutated and then diffed
    against the original.

    Args:
        object_in: The k8s object being admitted; must have metadata.name,
            metadata.namespace and spec.template.spec.

    Returns:
        The JsonPatch transforming `object_in` into the mutated copy.
    """
    metadata: dict = object_in["metadata"]
    job_name: str = metadata["name"]
    job_namespace: str = metadata["namespace"]

    mutated: dict = deepcopy(object_in)
    # Work on a single reference to the (possibly newly created) nodeSelector.
    node_selector: dict = mutated["spec"]["template"]["spec"].setdefault("nodeSelector", {})

    # The job-key selector is always added, regardless of environment.
    job_key: str = job_key_value(job_name, job_namespace)
    node_selector[job_key_label] = job_key
    webhook_logger.info(f'Job: {job_name} Added nodeSelector: {job_key_label}: {job_key}')

    # FORCE_ON_DEMAND=true strips the reservation and spot selectors, in that order.
    if os.environ.get(FORCE_ON_DEMAND) == "true":
        for label in (reservation_name_label, gke_spot_label):
            if label in node_selector:
                del node_selector[label]
                webhook_logger.info(f'Job: {job_name} Removed nodeSelector for node label: {label}')

    # A non-empty RESERVATION_LOCATION_HINT becomes the location-hint selector.
    location_hint_value: str = os.environ.get(LOCATION_HINT, "")
    if location_hint_value:
        node_selector[gke_location_hint_label] = location_hint_value
        webhook_logger.info(f'Job: {job_name} Added nodeSelector: {gke_location_hint_label}: {location_hint_value}')

    return JsonPatch.from_diff(object_in, mutated)


def job_key_value(job_name: str, job_namespace: str) -> str:
    """Return the SHA1 hex digest of "<namespace>/<name>" for the given Job."""
    namespaced_name = f'{job_namespace}/{job_name}'
    return sha1(namespaced_name)


def sha1(data: str) -> str:
    """Return the hexadecimal SHA1 digest of `data`, encoded with str.encode() defaults."""
    hasher = hashlib.sha1()
    hasher.update(data.encode())
    return hasher.hexdigest()
7 changes: 7 additions & 0 deletions tpu-provisioner/admission_controller/certificates/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Two files are required in this directory:

1. `certificate.crt`
2. `private.key`


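These are used to configure TLS for network communication to/from the webhook. Their base64-encoded contents go into the `webhook.crt` and `webhook.key` fields of the `admission-tls` Secret in `manifests/manifest.yaml`.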
78 changes: 78 additions & 0 deletions tpu-provisioner/admission_controller/manifests/manifest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
apiVersion: v1
kind: Secret
metadata:
name: admission-tls
type: Opaque
data:
webhook.crt: "" # base64 encoded certificate
webhook.key: "" # base64 encoded private key
---
apiVersion: v1
kind: Service
metadata:
name: mutating-webhook
spec:
selector:
app: mutating-webhook
ports:
- port: 5000
---
apiVersion: admissionregistration.k8s.io/v1
kind: MutatingWebhookConfiguration
metadata:
name: mutating-webhook
webhooks:
- name: mutating-webhook.default.svc
matchPolicy: Equivalent
admissionReviewVersions: ["v1"]
sideEffects: None
rules:
- operations: ["CREATE"]
apiGroups: ["batch"]
apiVersions: ["v1"]
resources: ["jobs"]
scope: "Namespaced"
failurePolicy: Ignore
timeoutSeconds: 20
clientConfig:
caBundle: # base64 CA bundle here
service:
namespace: default
name: mutating-webhook
path: /mutate
port: 5000
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: mutating-webhook
spec:
replicas: 1
selector:
matchLabels:
app: mutating-webhook
template:
metadata:
labels:
app: mutating-webhook
spec:
containers:
- name: mutating-webhook
image: "" # build container image, push to repository and add it here
imagePullPolicy: Always
ports:
- containerPort: 5000
env:
# Set environment variables for your deployment.
- name: RESERVATION_LOCATION_HINT
value: "cell"
- name: FORCE_ON_DEMAND
value: "false"
volumeMounts:
- name: certs-volume
readOnly: true
mountPath: "/certs"
volumes:
- name: certs-volume
secret:
secretName: admission-tls
Binary file not shown.
Empty file.
118 changes: 118 additions & 0 deletions tpu-provisioner/admission_controller/test/admission_controller_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os
import pytest
from jsonpatch import JsonPatch

from ..admission_controller import *

# Name and namespace shared by every Job fixture in this module.
test_job_name = "test-job"
test_job_ns = "default"

# RFC 6901: a forward slash '/' inside a JSON pointer reference token is escaped as '~1':
# https://datatracker.ietf.org/doc/html/rfc6901#section-3
# These unit tests validate the JsonPatch objects themselves, so the expected paths below
# use the escaped form of each label name.
# NOTE(review): the original comment said this escaping is "cleaned up" before the
# AdmissionReview is sent back to the apiserver; that step is not visible in this
# module - confirm against admission_controller.py.
escaped_reservation_label = reservation_name_label.replace('/', '~1')
escaped_gke_spot_label = gke_spot_label.replace('/', '~1')
escaped_gke_location_hint_label = gke_location_hint_label.replace('/', '~1')

@pytest.fixture(autouse=True)
def clear_environ(mocker):
    """Give every test an empty os.environ (applied automatically via autouse)."""
    mocker.patch.dict('os.environ', values={}, clear=True)


def test_base_patch_existing_nodeselector(mocker):
    """Only the job-key selector is added when a nodeSelector already exists."""
    pod_spec = {"nodeSelector": {"test-key": "test-value"}}
    job = {
        "kind": "Job",
        "metadata": {"name": test_job_name, "namespace": test_job_ns},
        "spec": {"template": {"spec": pod_spec}},
    }

    expected = JsonPatch([
        {
            'op': 'add',
            'path': f'/spec/template/spec/nodeSelector/{job_key_label}',
            'value': job_key_value(test_job_name, test_job_ns),
        },
    ])

    actual = make_patches(job)

    assert ordered(actual.patch) == ordered(expected.patch)


def test_base_patch_empty_nodeselector(mocker):
    """A whole nodeSelector holding job-key is added when the pod spec has none."""
    job = {
        "kind": "Job",
        "metadata": {"name": test_job_name, "namespace": test_job_ns},
        "spec": {"template": {"spec": {}}},
    }

    new_selector = {job_key_label: job_key_value(test_job_name, test_job_ns)}
    expected = JsonPatch([
        {
            'op': 'add',
            'path': f'/spec/template/spec/nodeSelector',
            'value': new_selector,
        },
    ])

    actual = make_patches(job)

    assert ordered(actual.patch) == ordered(expected.patch)


def test_force_on_demand(mocker):
    """FORCE_ON_DEMAND=true removes the spot and reservation selectors and adds job-key."""
    mocker.patch.dict('os.environ', {FORCE_ON_DEMAND: "true"})

    node_selector = {
        gke_spot_label: "true",
        reservation_name_label: "my-reservation",
    }
    job = {
        "kind": "Job",
        "metadata": {"name": test_job_name, "namespace": test_job_ns},
        "spec": {"template": {"spec": {"nodeSelector": node_selector}}},
    }

    expected = JsonPatch([
        {
            'op': 'add',
            'path': f'/spec/template/spec/nodeSelector/{job_key_label}',
            'value': job_key_value(test_job_name, test_job_ns),
        },
        {
            'op': 'remove',
            'path': f'/spec/template/spec/nodeSelector/{escaped_reservation_label}',
        },
        {
            'op': 'remove',
            'path': f'/spec/template/spec/nodeSelector/{escaped_gke_spot_label}',
        },
    ])

    actual = make_patches(job)

    assert ordered(actual.patch) == ordered(expected.patch)


def test_location_hint_with_reservation(mocker):
    """With LOCATION_HINT set, the hint selector is added and the reservation selector is kept."""
    mocker.patch.dict('os.environ', {LOCATION_HINT: "cell"})

    node_selector = {
        reservation_name_label: "my-reservation",
    }
    jobset = {
        "kind": "JobSet",
        "metadata": {"name": test_job_name, "namespace": test_job_ns},
        "spec": {"template": {"spec": {"nodeSelector": node_selector}}},
    }

    expected = JsonPatch([
        {
            'op': 'add',
            'path': f'/spec/template/spec/nodeSelector/{job_key_label}',
            'value': job_key_value(test_job_name, test_job_ns),
        },
        {
            'op': 'add',
            'path': f'/spec/template/spec/nodeSelector/{escaped_gke_location_hint_label}',
            'value': 'cell',
        },
    ])

    actual = make_patches(jobset)

    assert ordered(actual.patch) == ordered(expected.patch)


def ordered(obj):
    """Return an order-insensitive, comparable form of a JSON-like value.

    Dicts become sorted lists of (key, normalized value) pairs, lists become
    sorted lists of normalized items, and any other value is returned as-is.
    Used to compare two JSON objects for equality.
    """
    if isinstance(obj, list):
        return sorted(map(ordered, obj))
    if isinstance(obj, dict):
        return sorted((key, ordered(value)) for key, value in obj.items())
    return obj
Loading

0 comments on commit bca9b96

Please sign in to comment.