
Properly designate model state for actively training models when nodes crash or leave cluster #1317

Merged
merged 40 commits into main
Dec 12, 2023

Conversation

ryanbogan
Member

@ryanbogan ryanbogan commented Nov 20, 2023

Description

There is currently a bug where models get stuck in the TRAINING state when a node crashes or leaves the cluster. Since there is a write block on training models, they cannot be removed even though they are not actually training. This PR marks such models with their proper state (either ZOMBIE or FAILED) when a node crashes or leaves the cluster, so that zombie models can be deleted.
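The intended remapping can be sketched as a small state rule. This is a hedged illustration only: `ModelState` and `resolve_state_on_node_left` are hypothetical names, and the real PR decides between ZOMBIE and FAILED based on conditions not modeled here.

```python
from enum import Enum

class ModelState(Enum):
    TRAINING = "training"
    CREATED = "created"
    FAILED = "failed"
    ZOMBIE = "zombie"

def resolve_state_on_node_left(state: ModelState, training_node_left: bool) -> ModelState:
    # A model still marked TRAINING whose training node has left the
    # cluster is not actually training; remap it so it can be deleted.
    if state is ModelState.TRAINING and training_node_left:
        return ModelState.ZOMBIE
    return state
```

With a rule like this, a model whose node dropped mid-training no longer sits in TRAINING forever, so it is no longer protected by the training write block and can be deleted.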

Issues Resolved

#837

Check List

  • New functionality includes testing.
    • All tests pass
  • New functionality has been documented.
    • New functionality has javadoc added
  • Commits are signed as per the DCO using --signoff

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.


codecov bot commented Nov 20, 2023

Codecov Report

Attention: 29 lines in your changes are missing coverage. Please review.

Comparison is base (2e3ab95) 85.15% compared to head (b6b85a9) 85.00%.

Files Patch % Lines
.../knn/training/TrainingJobClusterStateListener.java 78.26% 12 Missing and 3 partials ⚠️
...org/opensearch/knn/training/TrainingJobRunner.java 22.22% 6 Missing and 1 partial ⚠️
...java/org/opensearch/knn/indices/ModelMetadata.java 88.88% 1 Missing and 3 partials ⚠️
.../main/java/org/opensearch/knn/index/IndexUtil.java 66.66% 1 Missing and 1 partial ⚠️
...plugin/transport/TrainingModelTransportAction.java 66.66% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##               main    #1317      +/-   ##
============================================
- Coverage     85.15%   85.00%   -0.16%     
- Complexity     1216     1241      +25     
============================================
  Files           160      161       +1     
  Lines          4958     5067     +109     
  Branches        457      473      +16     
============================================
+ Hits           4222     4307      +85     
- Misses          538      555      +17     
- Partials        198      205       +7     


Signed-off-by: Ryan Bogan <[email protected]>
@ryanbogan ryanbogan requested a review from jmazanec15 December 6, 2023 17:10
public void clusterChanged(ClusterChangedEvent event) {
if (event.localNodeClusterManager()) {
if (event.isNewCluster()) {
// When the cluster is first created, the cluster manager will update models that are still marked as training.
Collaborator

In which scenario can this happen? How would there be a training job when the cluster is first created?

Member Author

If the cluster crashes completely, the model will still be marked as training even though the background job isn't running.

Collaborator

What about the case where the index is restored?

Member Author

I'm not familiar with how the restoration code works. Is it possible to overwrite system indices?

Collaborator

Let's skip the restoring case here to move things forward.
Please test this scenario and make sure we mark state as failed.
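The listener behavior discussed in this thread can be sketched in Python as a simplified stand-in for the Java `TrainingJobClusterStateListener` (the `models` dict and its field names are hypothetical): on a brand-new cluster, every model still marked training is stale; on a node departure, only the models assigned to the departed nodes are.

```python
def on_cluster_changed(models, is_new_cluster, removed_node_ids):
    # models: dict of model_id -> {"state": ..., "node_id": ...}.
    # Marks stale TRAINING entries as failed, mirroring the
    # cluster-manager-side handling discussed above.
    for meta in models.values():
        if meta["state"] != "training":
            continue
        if is_new_cluster or meta["node_id"] in removed_node_ids:
            meta["state"] = "failed"
    return models
```

Only the elected cluster manager should run this kind of update, which is why the real listener first checks `event.localNodeClusterManager()`.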

@jmazanec15
Copy link
Member

@heemin32 @navneet1v @ryanbogan Discussing with Ryan offline, it seems it will be difficult for a node that drops and rejoins to properly detect that it dropped.

Therefore, I think my opinion has come back to the following: on the training node, before we serialize the model after training in the JNI completes, we just need to check whether the current state of the model (based on either the UUID or the combination of training node assignment and model name) in the metadata (or in the system index for now) is FAILED or missing, and, if so, cancel serialization.

If a node drops and the cluster manager detects it, the cluster state (or model index) will be updated to FAILED for that model. When the node rejoins, it will get this updated cluster state and see that the model is missing or FAILED. If the node drops and the cluster manager does not detect it, it doesn't matter: there is no need to cancel the job, because the model will not be marked as FAILED and the cluster will still think it is TRAINING.

It's not perfect for sure, but I think it's good enough for this particular use case for now. In general, the cluster may behave strangely if nodes are going up and down anyway. As long as we can eventually get it into a consistent state, we should be okay. We can do this manually by either restarting the node, or deleting the model that was trained during instability and asking the user to re-train with a more stable cluster.
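The proposal above reduces to a guard that the training node evaluates just before serialization. Sketched here in Python; `should_serialize` and the `metadata` dict are hypothetical stand-ins for the model system index lookup.

```python
def should_serialize(model_id, metadata):
    # metadata: dict of model_id -> state string. If the entry was
    # removed or marked failed while this node was out of the cluster,
    # the freshly trained model must not be written.
    state = metadata.get(model_id)
    return state is not None and state != "failed"
```

If the node was partitioned and the cluster manager marked the model FAILED, the rejoined node sees that state and skips the write; if the drop went undetected, the state is still training and serialization proceeds, which is the desired behavior in both cases.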

@heemin32
Copy link
Collaborator

heemin32 commented Dec 7, 2023

Good catch. I think it was either "cancel serialization on rejoin" or "cancel serialization on invalid state". Somehow we ended up using both of them, but I agree that one of them should suffice. "Cancel serialization on invalid state" would be simpler to implement and test than "cancel serialization on rejoin".

jmazanec15
jmazanec15 previously approved these changes Dec 7, 2023
Member

@jmazanec15 jmazanec15 left a comment

LGTM thanks! Make sure to add labels to the PR (bug fixes, v2.12.0 and backport-2.x)

@ryanbogan ryanbogan added Bug Fixes Changes to a system or product designed to handle a programming bug/glitch backport 2.x v2.12.0 labels Dec 7, 2023
Collaborator

@heemin32 heemin32 left a comment

LGTM. Thanks.

@ryanbogan ryanbogan requested a review from jmazanec15 December 8, 2023 16:00
@ryanbogan
Copy link
Member Author

ryanbogan commented Dec 12, 2023

Manual testing was conducted using the following python script:

import random
import sys
import time
import json
from opensearchpy import OpenSearch, RequestsHttpConnection


def _get_model_body():
    return {
                "name": "hnsw",
                "engine": "faiss",
                "space_type": "l2",
                "parameters": {
                    "m": 16,
                    "ef_construction": 128,
                    "encoder": {
                        "name": "pq",
                        "parameters": {
                            "code_size": 8,
                            "m": 32
                        }
                    }
                }
            }
# def _get_model_body():
#     return {
#                 "name": "ivf",
#                 "engine": "faiss",
#                 "space_type": "l2",
#                 "parameters": {
#                     "nlist": 4096,
#                     "nprobes": 64,
#                     "encoder": {
#                         "name": "pq",
#                         "parameters": {
#                             "code_size": 8,
#                             "m": 48
#                         }
#                     }
#                 }
#             }

def _get_test_body(field_name: str, dimension: int, model_id: str):
    return {
        'mappings': {
            'properties': {
                field_name: {
                    'type': 'knn_vector',
                    'dimension': dimension,
                    'model_id': model_id
                }
            }
        },
        'settings': {
            'index': {
                'knn': True,
            },
            'number_of_shards': 200,
            'number_of_replicas': 0,
        }
    }


def _get_train_body(field_name: str, dimension: int):
    return {
      'mappings': {
        'properties': {
            field_name: {
            'dimension': dimension,
            'type': 'knn_vector'
          }
        }
      },
      'settings': {
        'index': {
          'refresh_interval': '30s',
        },
        'number_of_shards': 1,
        'number_of_replicas': 0,
      }
    }


def create_index(os_client: OpenSearch, index_name: str, field_name: str, dimension: int, model_id: str = None):
    os_client.indices.delete(index=index_name, ignore=[400, 404])
    if model_id is None:
        os_client.indices.create(index=index_name, body=_get_train_body(field_name, dimension))
        return

    os_client.indices.create(index=index_name, body=_get_test_body(field_name, dimension, model_id))


def ingest_docs(os_client: OpenSearch, index_name: str, field_name: str, dimension: int, doc_count: int):
    bulk_size = 100

    def create_header(doc_id):
        return {'index': {'_index': index_name, '_id': doc_id}}

    def _bulk_transform(partition, offset: int):
        actions = []
        _ = [
            actions.extend([create_header(_id + offset), None]) for _id in range(len(partition))
        ]
        actions[1::2] = [_build_index_doc(vec) for vec in partition]
        return actions


    def _salt_vector(vec):
        return [v + random.random() for v in vec]


    def _build_index_doc(vec):
        return {field_name: _salt_vector(vec)}

    for i in range(0, doc_count, bulk_size):
        vectors = [[random.random() for _ in range(dimension)] for _ in range(bulk_size)]
        body = _bulk_transform(vectors, i)
        os_client.bulk(index=index_name, body=body)


def train_model(os_client: OpenSearch, train_index_name: str, train_field_name: str, dimension: int, model_id: str):
    timeout = 2400
    print(_get_model_body())
    body = {
        'training_index': train_index_name,
        'training_field': train_field_name,
        'description': "blah",
        'dimension': dimension,
        'method': _get_model_body(),
    }

    method = "POST"
    model_uri = "/_plugins/_knn/models/{}".format(model_id)
    os_client.transport.perform_request(method, "{}/_train".format(model_uri), body=body)

    start_time = time.time()
    while time.time() < start_time + timeout:
        time.sleep(1)
        model_response = os_client.transport.perform_request("GET", model_uri)
        print(model_response)
        if 'state' not in model_response:
            continue

        if model_response['state'] == 'created':
            return
        print(model_response['state'])

        if model_response['state'] == 'failed':
            raise Exception("Failed to create model: {}".format(model_response))

    raise Exception('Failed to create model: {} within timeout {} seconds'
                    .format(model_id, timeout))


def search_index(os_client: OpenSearch, index_name: str, field_name: str, dimension: int, query_count: int):
    def get_body(vec):
        return {
            'size': 10,
            'query': {
                'knn': {
                    field_name: {
                        'vector': vec,
                        'k': 10
                    }
                }
            }
        }

    for i in range(query_count):
        print("Query count {}".format((i+1)))
        query_response = os_client.search(index=index_name,
                                          body=get_body([random.random() for _ in range(dimension)]),
                                          request_timeout=100)
        print(query_response)


def _get_opensearch_client(endpoint: str, port: int):
    return OpenSearch(
        hosts=[{
            'host': endpoint,
            'port': port
        }],
        use_ssl=False,
        verify_certs=False,
        connection_class=RequestsHttpConnection,
        timeout=60,
    )


def main(args):
    TRAIN_INDEX_NAME = "train_index"
    TRAIN_FIELD_NAME = "train_field"
    MODEL_ID = "test_model"
    TEST_INDEX_NAME = "test_index"
    TEST_FIELD_NAME = "test_field"
    DIMENSION = 128
    DOC_COUNT = 5000

    QUERY_COUNT = 1

    step = args[1]
    os_client = _get_opensearch_client("localhost", 9200)

    if step == "train_setup":
        create_index(os_client, TRAIN_INDEX_NAME, TRAIN_FIELD_NAME, DIMENSION)
        ingest_docs(os_client, TRAIN_INDEX_NAME, TRAIN_FIELD_NAME, DIMENSION, DOC_COUNT)
        os_client.indices.refresh(index=TRAIN_INDEX_NAME)
        return

    if step == "train":
        train_model(os_client, TRAIN_INDEX_NAME, TRAIN_FIELD_NAME, DIMENSION, MODEL_ID)
        return

    if step == "ingest":
        create_index(os_client, TEST_INDEX_NAME, TEST_FIELD_NAME, DIMENSION, model_id=MODEL_ID)
        ingest_docs(os_client, TEST_INDEX_NAME, TEST_FIELD_NAME, DIMENSION, DOC_COUNT)
        os_client.indices.refresh(index=TEST_INDEX_NAME)
        return

    if step == "search":
        search_index(os_client, TEST_INDEX_NAME, TEST_FIELD_NAME, DIMENSION, QUERY_COUNT)
        return


if __name__ == "__main__":
    main(sys.argv)

Single node cluster crash:

  1. In terminal 1, ./gradlew run
  2. In a separate terminal, python3 test.py train_setup
  3. In terminal 2, python3 test.py train
  4. In terminal 1, control + C to crash cluster
  5. In terminal 1, ./gradlew run --preserve-data
  6. Once the cluster is up and running, use curl or Postman to hit the get model API. The model should be marked FAILED and able to be deleted.

Multi-node cluster crash:

  1. Same steps as above but for each ./gradlew run, add -PnumNodes=3

Node leaving while cluster is still running:

  1. Open /etc/pf.conf and add the following rule to the bottom of the file, which will block transport traffic to the specified port:
    1. block in quick inet proto { tcp, udp } from any to any port 9300
  2. sudo pfctl -f /etc/pf.conf
  3. Add a log statement in TrainingJobClusterStateListener to print the node ephemeral ID in the clusterChanged() method
  4. In terminal 1, ./gradlew run -PnumNodes=3
  5. In terminal 2, python3 test.py train_setup
  6. In terminal 2, python3 test.py train
  7. Ensure that the node assignment printed out by the train script is the same as integ-test0 ephemeral id in terminal 1.
    1. If not, control + C and restart from step 4
  8. In terminal 3, sudo pfctl -e
  9. At this point, cluster communication fails, and once the checks fail three times, a new cluster manager node is elected.
  10. Once there is a new cluster manager node, the model will be marked as FAILED, which can be validated via curl/Postman.
  11. In terminal 3, sudo pfctl -d
  12. The cluster will stabilize as the node rejoins the cluster.
  13. Once the training completes, the log for “Skipping serialization of model” is printed in terminal 1.
  14. The model is still marked as failed, and can be deleted.

Member

@jmazanec15 jmazanec15 left a comment

LGTM

@ryanbogan ryanbogan merged commit 33da521 into opensearch-project:main Dec 12, 2023
48 of 49 checks passed
@opensearch-trigger-bot
Contributor

The backport to 2.x failed:

The process '/usr/bin/git' failed with exit code 1

To backport manually, run these commands in your terminal:

# Fetch latest updates from GitHub
git fetch
# Create a new working tree
git worktree add .worktrees/backport-2.x 2.x
# Navigate to the new working tree
cd .worktrees/backport-2.x
# Create a new branch
git switch --create backport/backport-1317-to-2.x
# Cherry-pick the merged commit of this pull request and resolve the conflicts
git cherry-pick -x --mainline 1 33da521e0f98317b4700b62807e1d21b11f54a71
# Push it to GitHub
git push --set-upstream origin backport/backport-1317-to-2.x
# Go back to the original working tree
cd ../..
# Delete the working tree
git worktree remove .worktrees/backport-2.x

Then, create a pull request where the base branch is 2.x and the compare/head branch is backport/backport-1317-to-2.x.
