forked from GoogleCloudPlatform/ai-on-gke
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TPU Provisioner] Create admission controller (GoogleCloudPlatform#687)
* initial commit of admission controller * use jsonpatch library
- Loading branch information
1 parent
62f8d3c
commit bca9b96
Showing
12 changed files
with
442 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# don't add certificates | ||
certificates/*.crt | ||
certificates/*.key | ||
|
||
__pycache__/ | ||
.pytest_cache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
FROM python:3.10-slim-buster | ||
WORKDIR /webhook | ||
COPY requirements.txt /webhook | ||
COPY admission_controller.py /webhook | ||
RUN pip install --no-cache-dir --upgrade -r /webhook/requirements.txt | ||
CMD ["uvicorn", "admission_controller:app", "--host", "0.0.0.0", "--port", "5000","--ssl-keyfile=/certs/webhook.key", "--ssl-certfile=/certs/webhook.crt"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# TPU Provisioner Admission Controller | ||
|
||
This is a custom k8s admission controller that can be paired with the TPU provisioner | ||
to dynamically inject node selectors into a Job's pod template based on environment | ||
variables. The TPU provisioner will then provision slices based on the values of | ||
these node selectors. | ||
|
||
**NOTE**: This is not a generalized solution that works out of the box for any user - the values | ||
injected by the admission controller are just examples that the user is responsible | ||
for changing to fit their use case. | ||
|
||
## Project Structure | ||
|
||
``` | ||
|- admission_controller.py (mutating webhook) | ||
|- certificates (add TLS certificates here) | ||
|- manifests (deployment manifest for admission controller) | ||
|- test (unit tests) | ||
| - tests | ||
| |-- admission_controller_test.py (unit tests) | ||
| |-- manual_e2e/ (JobSet manifests for manual e2e tests) | ||
| | ... | ||
``` | ||
|
||
### Prepare container image | ||
|
||
1. Build container image: `docker build admission-controller -f Dockerfile .` | ||
2. Tag container image: `docker tag admission-controller gcr.io/${PROJECT}/admission-controller:v0.1.0` | ||
2. Push container image: `docker push gcr.io/${PROJECT}/admission-controller:v0.1.0` | ||
|
||
Update the Deployment in `manifests/manifest.yaml` with this container image. | ||
|
||
### Run Unit tests | ||
|
||
This project uses [pytest](https://docs.pytest.org) for unit testing. | ||
|
||
To run unit tests, run the command `pytest` from the `admission_controller/` directory. | ||
|
||
### Run E2E tests | ||
|
||
E2E testing is currently done manually via the following steps: | ||
|
||
1. [Install JobSet](https://jobset.sigs.k8s.io/docs/installation/) | ||
2. **Deploy admission controller**: Run `kubectl apply -f manifests/` from the `admission_controller/` directory. | ||
3. **Create a test JobSet**: Run `kubectl apply -f test/test-jobset.yaml` | ||
4. **Check Jobs were mutated correctly**: Run `kubectl describe jobs` and view the nodeSelectors in the pod template. |
Empty file.
103 changes: 103 additions & 0 deletions
103
tpu-provisioner/admission_controller/admission_controller.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
#!/usr/bin/env python3 | ||
import os | ||
import json | ||
import base64 | ||
import logging | ||
import hashlib | ||
from fastapi import FastAPI, Body | ||
from jsonpatch import JsonPatch | ||
from copy import deepcopy | ||
|
||
app = FastAPI() | ||
|
||
webhook_logger = logging.getLogger(__name__) | ||
webhook_logger.setLevel(logging.INFO) | ||
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s") | ||
|
||
# environment variables | ||
LOCATION_HINT = "RESERVATION_LOCATION_HINT" | ||
ALWAYS_HINT_TIME = "ALWAYS_HINT_TIME" | ||
FORCE_ON_DEMAND = "FORCE_ON_DEMAND" | ||
|
||
# labels | ||
job_key_label = "job-key" | ||
reservation_name_label = "cloud.google.com/reservation-name" | ||
gke_spot_label = "cloud.google.com/gke-spot" | ||
gke_location_hint_label = "cloud.google.com/gke-location-hint" | ||
|
||
# API endpoint | ||
@app.post("/mutate") | ||
def mutate_request(request: dict = Body(...)): | ||
'''API endpoint for the admission controller mutating webhook.''' | ||
uid: str = request["request"]["uid"] | ||
|
||
object_in: dict = request["request"]["object"] | ||
webhook_logger.info(f'Patching {object_in["kind"]} {object_in["metadata"]["namespace"]}/{object_in["metadata"]["name"]}') | ||
|
||
response: dict = admission_review(uid, object_in) | ||
webhook_logger.info(f'Response: {json.dumps(response)}') | ||
return response | ||
|
||
|
||
def admission_review(uid: str, object_in: dict) -> dict: | ||
'''Returns an AdmissionReview JSONPatch for the given AdmissionRequest.''' | ||
return { | ||
"apiVersion": "admission.k8s.io/v1", | ||
"kind": "AdmissionReview", | ||
"response": { | ||
"uid": uid, | ||
"allowed": True, | ||
"patchType": "JSONPatch", | ||
"status": {"message": f"Patched {object_in['kind']}: {object_in['metadata']['namespace']}/{object_in['metadata']['name']}"}, | ||
"patch": patch(object_in), | ||
}, | ||
} | ||
|
||
|
||
def patch(object_in: dict) -> str: | ||
'''Returns a base64 encoded patch for the given k8s object.''' | ||
patches: list[dict] = make_patches(object_in) | ||
return base64.b64encode(str(patches).encode()).decode() | ||
|
||
|
||
def make_patches(object_in: dict) -> JsonPatch: | ||
'''Generates a JsonPatch for Job mutations that are based on environment variables.''' | ||
job_name: str = object_in["metadata"]["name"] | ||
job_namespace: str = object_in["metadata"]["namespace"] | ||
modified_object: dict = deepcopy(object_in) | ||
|
||
if "nodeSelector" not in modified_object["spec"]["template"]["spec"]: | ||
modified_object["spec"]["template"]["spec"]["nodeSelector"] = {} | ||
|
||
# Add job-key node selector unconditionally. | ||
modified_object["spec"]["template"]["spec"]["nodeSelector"][job_key_label] = job_key_value(job_name, job_namespace) | ||
webhook_logger.info(f'Job: {job_name} Added nodeSelector: {job_key_label}: {job_key_value(job_name, job_namespace)}') | ||
|
||
if os.environ.get(FORCE_ON_DEMAND) == "true": | ||
# Remove reservation label if FORCE_ON_DEMAND is set. | ||
if reservation_name_label in modified_object["spec"]["template"]["spec"]["nodeSelector"]: | ||
del modified_object["spec"]["template"]["spec"]["nodeSelector"][reservation_name_label] | ||
webhook_logger.info(f'Job: {job_name} Removed nodeSelector for node label: {reservation_name_label}') | ||
# Remove spot label if FORCE_ON_DEMAND is set. | ||
if gke_spot_label in modified_object["spec"]["template"]["spec"]["nodeSelector"]: | ||
del modified_object["spec"]["template"]["spec"]["nodeSelector"][gke_spot_label] | ||
webhook_logger.info(f'Job: {job_name} Removed nodeSelector for node label: {gke_spot_label}') | ||
|
||
# Set location hint nodeSelector if RESERVATION_LOCATION_HINT is set. | ||
location_hint_value: str = os.environ.get(LOCATION_HINT, "") | ||
if location_hint_value != "": | ||
modified_object["spec"]["template"]["spec"]["nodeSelector"][gke_location_hint_label] = location_hint_value | ||
webhook_logger.info(f'Job: {job_name} Added nodeSelector: {gke_location_hint_label}: {location_hint_value}') | ||
|
||
patch: JsonPatch = JsonPatch.from_diff(object_in, modified_object) | ||
return patch | ||
|
||
|
||
def job_key_value(job_name: str, job_namespace: str) -> str: | ||
'''Returns the SHA1 hash of the namespaced Job name.''' | ||
return sha1(f'{job_namespace}/{job_name}') | ||
|
||
|
||
def sha1(data: str) -> str: | ||
'''Returns the SHA1 hash of the given string.''' | ||
return hashlib.sha1(data.encode()).hexdigest() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Two files are required in this directory: | ||
|
||
1. `certificate.crt` | ||
2. `private.key` | ||
|
||
|
||
These are used to configure TLS for network communication to/from the webhook. |
78 changes: 78 additions & 0 deletions
78
tpu-provisioner/admission_controller/manifests/manifest.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
apiVersion: v1 | ||
kind: Secret | ||
metadata: | ||
name: admission-tls | ||
type: Opaque | ||
data: | ||
webhook.crt: "" # base64 encoded certificate | ||
webhook.key: "" # base64 encoded private key | ||
--- | ||
apiVersion: v1 | ||
kind: Service | ||
metadata: | ||
name: mutating-webhook | ||
spec: | ||
selector: | ||
app: mutating-webhook | ||
ports: | ||
- port: 5000 | ||
--- | ||
apiVersion: admissionregistration.k8s.io/v1 | ||
kind: MutatingWebhookConfiguration | ||
metadata: | ||
name: mutating-webhook | ||
webhooks: | ||
- name: mutating-webhook.default.svc | ||
matchPolicy: Equivalent | ||
admissionReviewVersions: ["v1"] | ||
sideEffects: None | ||
rules: | ||
- operations: ["CREATE"] | ||
apiGroups: ["batch"] | ||
apiVersions: ["v1"] | ||
resources: ["jobs"] | ||
scope: "Namespaced" | ||
failurePolicy: Ignore | ||
timeoutSeconds: 20 | ||
clientConfig: | ||
caBundle: # base64 CA bundle here | ||
service: | ||
namespace: default | ||
name: mutating-webhook | ||
path: /mutate | ||
port: 5000 | ||
--- | ||
apiVersion: apps/v1 | ||
kind: Deployment | ||
metadata: | ||
name: mutating-webhook | ||
spec: | ||
replicas: 1 | ||
selector: | ||
matchLabels: | ||
app: mutating-webhook | ||
template: | ||
metadata: | ||
labels: | ||
app: mutating-webhook | ||
spec: | ||
containers: | ||
- name: mutating-webhook | ||
image: "" # build container image, push to repository and add it here | ||
imagePullPolicy: Always | ||
ports: | ||
- containerPort: 5000 | ||
env: | ||
# Set environment variables for your deployment. | ||
- name: RESERVATION_LOCATION_HINT | ||
value: "cell" | ||
- name: FORCE_ON_DEMAND | ||
value: "false" | ||
volumeMounts: | ||
- name: certs-volume | ||
readOnly: true | ||
mountPath: "/certs" | ||
volumes: | ||
- name: certs-volume | ||
secret: | ||
secretName: admission-tls |
Binary file not shown.
Empty file.
118 changes: 118 additions & 0 deletions
118
tpu-provisioner/admission_controller/test/admission_controller_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import os | ||
import pytest | ||
from jsonpatch import JsonPatch | ||
|
||
from ..admission_controller import * | ||
|
||
test_job_name = "test-job" | ||
test_job_ns = "default" | ||
|
||
# RFC 6901: escaped forward slash '/' in JSON pointer is encoded as '~1': https://datatracker.ietf.org/doc/html/rfc6901#section-3 | ||
# This is cleaned up before sending the AdmissionReview back to the apiserver, but these unit tests | ||
# validate the JsonPatch objects themselves, before cleanup. | ||
escaped_reservation_label = reservation_name_label.replace('/', '~1') | ||
escaped_gke_spot_label = gke_spot_label.replace('/', '~1') | ||
escaped_gke_location_hint_label = gke_location_hint_label.replace('/', '~1') | ||
|
||
@pytest.fixture(autouse=True) | ||
def clear_environ(mocker): | ||
"""Clears environment variables before each test.""" | ||
mocker.patch.dict('os.environ', clear=True) | ||
|
||
|
||
def test_base_patch_existing_nodeselector(mocker): | ||
"""Tests the basic patch operation (adding job-key selector).""" | ||
object_in = { | ||
"kind": "Job", | ||
"metadata": {"name": test_job_name, "namespace": test_job_ns}, | ||
"spec": {"template": {"spec": {"nodeSelector": {"test-key": "test-value"}}}} | ||
} | ||
|
||
expected_patches = JsonPatch([ | ||
{'op': 'add', 'path': f'/spec/template/spec/nodeSelector/{job_key_label}', 'value': job_key_value(test_job_name, test_job_ns)}, | ||
]) | ||
|
||
patches = make_patches(object_in) | ||
assert ordered(patches.patch) == ordered(expected_patches.patch) | ||
|
||
|
||
def test_base_patch_empty_nodeselector(mocker): | ||
"""Tests the basic patch operation (adding job-key selector).""" | ||
object_in = { | ||
"kind": "Job", | ||
"metadata": {"name": test_job_name, "namespace": test_job_ns}, | ||
"spec": {"template": {"spec": {}}} | ||
} | ||
|
||
expected_patches = JsonPatch([ | ||
{'op': 'add', 'path': f'/spec/template/spec/nodeSelector', 'value': {job_key_label: job_key_value(test_job_name, test_job_ns)}}, | ||
]) | ||
patches = make_patches(object_in) | ||
assert ordered(patches.patch) == ordered(expected_patches.patch) | ||
|
||
|
||
def test_force_on_demand(mocker): | ||
"""Tests patch operations when FORCE_ON_DEMAND is set.""" | ||
mocker.patch.dict('os.environ', {FORCE_ON_DEMAND: "true"}) | ||
|
||
object_in = { | ||
"kind": "Job", | ||
"metadata": {"name": test_job_name, "namespace": test_job_ns}, | ||
"spec": { | ||
"template": { | ||
"spec": { | ||
"nodeSelector": { | ||
gke_spot_label: "true", | ||
reservation_name_label: "my-reservation", | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
expected_patches = JsonPatch([ | ||
{'op': 'add', 'path': f'/spec/template/spec/nodeSelector/{job_key_label}', 'value': job_key_value(test_job_name, test_job_ns)}, | ||
{'op': 'remove', 'path': f'/spec/template/spec/nodeSelector/{escaped_reservation_label}'}, | ||
{'op': 'remove', 'path': f'/spec/template/spec/nodeSelector/{escaped_gke_spot_label}'}, | ||
]) | ||
|
||
patches = make_patches(object_in) | ||
assert ordered(patches.patch) == ordered(expected_patches.patch) | ||
|
||
|
||
def test_location_hint_with_reservation(mocker): | ||
"""Tests patch operations with LOCATION_HINT and using a reservation.""" | ||
mocker.patch.dict('os.environ', {LOCATION_HINT: "cell"}) | ||
|
||
object_in = { | ||
"kind": "JobSet", | ||
"metadata": {"name": test_job_name, "namespace": test_job_ns}, | ||
"spec": { | ||
"template": { | ||
"spec": { | ||
"nodeSelector": { | ||
reservation_name_label: "my-reservation", | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
expected_patches = JsonPatch([ | ||
{'op': 'add', 'path': f'/spec/template/spec/nodeSelector/{job_key_label}', 'value': job_key_value(test_job_name, test_job_ns)}, | ||
{'op': 'add', 'path': f'/spec/template/spec/nodeSelector/{escaped_gke_location_hint_label}', 'value': 'cell'}, | ||
]) | ||
|
||
patches = make_patches(object_in) | ||
assert ordered(patches.patch) == ordered(expected_patches.patch) | ||
|
||
|
||
def ordered(obj): | ||
'''Recursively sort a dictionary or list of dictionaries. | ||
Used for equality comparison of two JSON objects.''' | ||
if isinstance(obj, dict): | ||
return sorted((k, ordered(v)) for k, v in obj.items()) | ||
if isinstance(obj, list): | ||
return sorted(ordered(x) for x in obj) | ||
else: | ||
return obj |
Oops, something went wrong.