Properly designate model state for actively training models when nodes crash or leave cluster #1317

Merged: 40 commits, Dec 12, 2023
d9269b3
Initial implementation
ryanbogan Nov 6, 2023
945a4da
Fix compile errors for tests
ryanbogan Nov 9, 2023
b2fc712
Temporary tests
ryanbogan Nov 13, 2023
9e21f07
Ensure backwards compatibility and add zombie to model state enum
ryanbogan Nov 17, 2023
ad09839
Update current tests
ryanbogan Nov 17, 2023
0537111
Fix current integration tests
ryanbogan Nov 20, 2023
c25075f
Fix unit tests with new changes
ryanbogan Nov 20, 2023
a28ad42
Add unit tests
ryanbogan Nov 20, 2023
3f31741
Fix spotless
ryanbogan Nov 20, 2023
85ed0bf
Add changelog entry
ryanbogan Nov 20, 2023
f464d2e
Delete temporary test file
ryanbogan Nov 20, 2023
ba7d5f2
Remove temporary changes to build.gradle
ryanbogan Nov 20, 2023
91778e1
Add more backwards compatibility
ryanbogan Nov 21, 2023
de2c3aa
Attempt to fix bwc tests
ryanbogan Nov 21, 2023
14aa761
Fix spotless
ryanbogan Nov 21, 2023
62d0082
Remove star imports
ryanbogan Nov 21, 2023
47a3800
Add another unit test
ryanbogan Nov 21, 2023
c15dc9a
Modify unit test to increase coverage
ryanbogan Nov 21, 2023
c7e0dcf
Change unit test to increase coverage
ryanbogan Nov 21, 2023
25eab9c
Merge branch 'main' into model_stuck_train_state
ryanbogan Nov 21, 2023
66e787b
Add method description for clusterChanged
ryanbogan Nov 22, 2023
257623a
Address PR feedback
ryanbogan Nov 28, 2023
85635c3
Refactor into TrainingJobClusterStateListener
ryanbogan Nov 28, 2023
bbd3b47
Make node assignment final and added in the constructor of TrainingJob
ryanbogan Nov 29, 2023
ea73a16
Remove clusterService from TrainingJobRunner
ryanbogan Nov 29, 2023
012a76e
Address PR Feedback
ryanbogan Dec 1, 2023
613a28e
Add flag when node rejoins and check when serializing model
ryanbogan Dec 1, 2023
ac5df23
Address PR feedback
ryanbogan Dec 1, 2023
91ea4df
Merge branch 'main' into model_stuck_train_state
ryanbogan Dec 1, 2023
c7b6281
Address PR Feedback
ryanbogan Dec 4, 2023
bf20c77
Fix spotless
ryanbogan Dec 4, 2023
4148b28
Test new version check for StreamInput
ryanbogan Dec 5, 2023
c1bdac9
Remove check to test new method
ryanbogan Dec 5, 2023
fdabe9f
Add version check for stream input/output logic
ryanbogan Dec 5, 2023
3eb2375
Address PR Feedback
ryanbogan Dec 6, 2023
80574a2
Address PR Feedback
ryanbogan Dec 7, 2023
6f1a064
Address PR Feedback
ryanbogan Dec 7, 2023
bf4407d
Address PR Feedback
ryanbogan Dec 7, 2023
586797f
Address PR Feedback
ryanbogan Dec 7, 2023
b6b85a9
Merge branch main into model_stuck_train_state
ryanbogan Dec 7, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)

### Infrastructure
* Upgrade gradle to 8.4 [#1289](https://github.com/opensearch-project/k-NN/pull/1289)
### Documentation
@@ -257,6 +257,6 @@ public String modelIndexMapping(String fieldName, String modelId) throws IOExcep
}

private ModelMetadata getModelMetadata() {
- return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "");
+ return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "", "");
}
}
@@ -45,6 +45,7 @@ public class KNNConstants {
public static final String MODEL_TIMESTAMP = "timestamp";
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
@@ -36,10 +36,14 @@

public class IndexUtil {

public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
}
};

@@ -251,4 +255,12 @@
}
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(minimalRequiredVersion);
}

public static boolean isVersionOnOrAfterMinRequiredVersion(Version version, String key) {
Version minimalRequiredVersion = minimalRequiredVersionMap.get(key);
if (minimalRequiredVersion == null) {
return false;

}
return version.onOrAfter(minimalRequiredVersion);
}
}
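Reviewer note: the new `isVersionOnOrAfterMinRequiredVersion` overload is a lookup from feature key to minimum version, with unknown keys treated as unsupported. A minimal standalone sketch of that behavior follows. `SimpleVersion` and `FeatureGate` are hypothetical stand-ins introduced here for illustration; they are not OpenSearch's `Version` class or the plugin's `IndexUtil`.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a version type; only the ordering matters for this sketch.
final class SimpleVersion implements Comparable<SimpleVersion> {
    final int major;
    final int minor;
    final int patch;

    SimpleVersion(int major, int minor, int patch) {
        this.major = major;
        this.minor = minor;
        this.patch = patch;
    }

    @Override
    public int compareTo(SimpleVersion other) {
        if (major != other.major) return Integer.compare(major, other.major);
        if (minor != other.minor) return Integer.compare(minor, other.minor);
        return Integer.compare(patch, other.patch);
    }

    boolean onOrAfter(SimpleVersion other) {
        return compareTo(other) >= 0;
    }
}

// Hypothetical helper mirroring the shape of the lookup in the diff.
final class FeatureGate {
    // Feature key -> first version that understands the feature (values taken from the diff).
    private static final Map<String, SimpleVersion> MIN_VERSIONS = new HashMap<>();

    static {
        MIN_VERSIONS.put("ignore_unmapped", new SimpleVersion(2, 11, 0));
        MIN_VERSIONS.put("training_node_assignment", new SimpleVersion(2, 12, 0));
    }

    // Unknown keys are reported as unsupported, matching the null check in the diff.
    static boolean isVersionOnOrAfterMinRequiredVersion(SimpleVersion version, String key) {
        SimpleVersion minimalRequiredVersion = MIN_VERSIONS.get(key);
        return minimalRequiredVersion != null && version.onOrAfter(minimalRequiredVersion);
    }

    public static void main(String[] args) {
        SimpleVersion older = new SimpleVersion(2, 11, 0);
        SimpleVersion newer = new SimpleVersion(2, 12, 0);
        System.out.println(isVersionOnOrAfterMinRequiredVersion(older, "training_node_assignment"));
        System.out.println(isVersionOnOrAfterMinRequiredVersion(newer, "training_node_assignment"));
    }
}
```

Returning `false` for an unregistered key fails closed: a caller that mistypes a key skips the new field rather than writing something an older node cannot read.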
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
@@ -287,6 +287,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp());
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());
}
};

80 changes: 65 additions & 15 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
@@ -11,6 +11,7 @@

package org.opensearch.knn.indices;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.core.common.io.stream.StreamInput;
@@ -19,6 +20,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

@@ -34,7 +36,9 @@
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.common.KNNConstants.MODEL_NODE_ASSIGNMENT;

@Log4j2
public class ModelMetadata implements Writeable, ToXContentObject {

private static final String DELIMITER = ",";
@@ -46,6 +50,7 @@
private AtomicReference<ModelState> state;
final private String timestamp;
final private String description;
final private String trainingNodeAssignment;
private String error;

/**
@@ -54,6 +59,7 @@
* @param in Stream input
*/
public ModelMetadata(StreamInput in) throws IOException {
String tempTrainingNodeAssignment;
this.knnEngine = KNNEngine.getEngine(in.readString());
this.spaceType = SpaceType.getSpace(in.readString());
this.dimension = in.readInt();
@@ -64,6 +70,12 @@
// which is checked in constructor and setters
this.description = in.readString();
this.error = in.readString();

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
this.trainingNodeAssignment = in.readString();
} else {
this.trainingNodeAssignment = "";

}
}

/**
@@ -84,7 +96,8 @@
ModelState modelState,
String timestamp,
String description,
- String error
+ String error,
+ String trainingNodeAssignment
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
@@ -104,6 +117,7 @@
this.timestamp = Objects.requireNonNull(timestamp, "timestamp must not be null");
this.description = Objects.requireNonNull(description, "description must not be null");
this.error = Objects.requireNonNull(error, "error must not be null");
this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null");
}

/**
@@ -169,6 +183,15 @@
return error;
}

/**
* getter for model's node assignment
*
* @return trainingNodeAssignment
*/
public String getNodeAssignment() {
return trainingNodeAssignment;
}

/**
* setter for model's state
*
@@ -197,7 +220,8 @@
getState().toString(),
timestamp,
description,
- error
+ error,
+ trainingNodeAssignment
);
}

@@ -240,22 +264,36 @@
public static ModelMetadata fromString(String modelMetadataString) {
String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1);

- if (modelMetadataArray.length != 7) {
+ // Training node assignment was added as a field in Version 2.12.0
+ // Because models can be created on older versions and the cluster can be upgraded after,
+ // we need to accept model metadata arrays both with and without the training node assignment.
+ if (modelMetadataArray.length == 7) {
+ log.debug("Model metadata array does not contain training node assignment. Assuming empty string.");
+ KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
+ SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
+ int dimension = Integer.parseInt(modelMetadataArray[2]);
+ ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
+ String timestamp = modelMetadataArray[4];
+ String description = modelMetadataArray[5];
+ String error = modelMetadataArray[6];
+ return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, "");
+ } else if (modelMetadataArray.length == 8) {
+ log.debug("Model metadata contains training node assignment");
+ KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
+ SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
+ int dimension = Integer.parseInt(modelMetadataArray[2]);
+ ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
+ String timestamp = modelMetadataArray[4];
+ String description = modelMetadataArray[5];
+ String error = modelMetadataArray[6];
+ String trainingNodeAssignment = modelMetadataArray[7];
+ return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, trainingNodeAssignment);
+ } else {
  throw new IllegalArgumentException(
  "Illegal format for model metadata. Must be of the form "
- + "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\"."
+ + "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or \"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\"."
  );
  }
-
- KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
- SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
- int dimension = Integer.parseInt(modelMetadataArray[2]);
- ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
- String timestamp = modelMetadataArray[4];
- String description = modelMetadataArray[5];
- String error = modelMetadataArray[6];
-
- return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error);
}
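
Reviewer note: the two branches in `fromString` differ only in the eighth field, so the shared parsing could be written once with the node assignment defaulting to an empty string. The sketch below illustrates that shape on plain strings. `ParsedMetadata` and `MetadataParser` are hypothetical names for this example and deliberately avoid the plugin's `KNNEngine`, `SpaceType`, and `ModelState` types.

```java
// Hypothetical holder for the parsed fields; the real class uses richer types.
final class ParsedMetadata {
    final String engine;
    final String spaceType;
    final int dimension;
    final String state;
    final String timestamp;
    final String description;
    final String error;
    final String nodeAssignment;

    ParsedMetadata(String engine, String spaceType, int dimension, String state,
                   String timestamp, String description, String error, String nodeAssignment) {
        this.engine = engine;
        this.spaceType = spaceType;
        this.dimension = dimension;
        this.state = state;
        this.timestamp = timestamp;
        this.description = description;
        this.error = error;
        this.nodeAssignment = nodeAssignment;
    }
}

final class MetadataParser {
    private static final String DELIMITER = ",";

    // Accepts the 7-field (pre-2.12.0) and 8-field (2.12.0 and later) layouts.
    static ParsedMetadata fromString(String metadata) {
        // A limit of -1 keeps trailing empty fields, e.g. an empty error string.
        String[] parts = metadata.split(DELIMITER, -1);
        if (parts.length != 7 && parts.length != 8) {
            throw new IllegalArgumentException("Expected 7 or 8 comma-separated fields, got " + parts.length);
        }
        // Older metadata has no node assignment; default it to the empty string.
        String nodeAssignment = parts.length == 8 ? parts[7] : "";
        return new ParsedMetadata(parts[0], parts[1], Integer.parseInt(parts[2]),
            parts[3], parts[4], parts[5], parts[6], nodeAssignment);
    }

    public static void main(String[] args) {
        ParsedMetadata legacy = fromString("faiss,l2,4,created,2021-03-27,test model,");
        ParsedMetadata current = fromString("faiss,l2,4,created,2021-03-27,test model,,node-1");
        System.out.println("[" + legacy.nodeAssignment + "] [" + current.nodeAssignment + "]");
    }
}
```

The trade-off is the loss of the two distinct debug log lines in the diff; if those are wanted, a single log statement keyed on `parts.length` would preserve them.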

private static String objectToString(Object value) {
@@ -282,6 +320,11 @@
Object timestamp = modelSourceMap.get(KNNConstants.MODEL_TIMESTAMP);
Object description = modelSourceMap.get(KNNConstants.MODEL_DESCRIPTION);
Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR);
Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT);

if (trainingNodeAssignment == null) {
trainingNodeAssignment = "";
}

ModelMetadata modelMetadata = new ModelMetadata(
KNNEngine.getEngine(objectToString(engine)),
@@ -290,7 +333,8 @@
ModelState.getModelState(objectToString(state)),
objectToString(timestamp),
objectToString(description),
- objectToString(error)
+ objectToString(error),
+ objectToString(trainingNodeAssignment)
);
return modelMetadata;
}
@@ -304,6 +348,9 @@
out.writeString(getTimestamp());
out.writeString(getDescription());
out.writeString(getError());
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
out.writeString(getNodeAssignment());
}
}

@Override
@@ -316,6 +363,9 @@
builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue());
builder.field(DIMENSION, getDimension());
builder.field(KNN_ENGINE, getKnnEngine().getName());
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
builder.field(MODEL_NODE_ASSIGNMENT, getNodeAssignment());
}
return builder;
}
}
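Reviewer note: the `StreamInput` constructor and `writeTo` must apply the same version check, otherwise a reader would try to consume a field the writer never sent (or leave one unread). The sketch below shows that symmetry using `java.io` data streams. `GatedWire`, the integer version encoding, and the two-field payload are assumptions made for this example; they are not the OpenSearch stream API.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

final class GatedWire {
    // Hypothetical integer encoding of a version; 21200 stands in for 2.12.0.
    static final int MIN_VERSION_FOR_NODE_ASSIGNMENT = 21200;

    // Writes the optional field only when the peer's version understands it.
    static byte[] write(String error, String nodeAssignment, int peerVersion) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(error);
            if (peerVersion >= MIN_VERSION_FOR_NODE_ASSIGNMENT) {
                out.writeUTF(nodeAssignment);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Reads with the same gate; older peers yield an empty node assignment.
    static String[] read(byte[] payload, int peerVersion) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            String error = in.readUTF();
            String nodeAssignment = peerVersion >= MIN_VERSION_FOR_NODE_ASSIGNMENT ? in.readUTF() : "";
            return new String[] { error, nodeAssignment };
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        String[] fromNewPeer = read(write("", "node-1", 21200), 21200);
        String[] fromOldPeer = read(write("", "node-1", 21100), 21100);
        System.out.println("[" + fromNewPeer[1] + "] [" + fromOldPeer[1] + "]");
    }
}
```

Both sides key the decision off the version negotiated for the stream, which is why the diff uses `in.getVersion()` and `out.getVersion()` rather than the cluster-wide minimum version used in `toXContent`.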
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
@@ -78,6 +78,7 @@
import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction;
import org.opensearch.knn.training.TrainingJobClusterStateListener;
import org.opensearch.knn.training.TrainingJobRunner;
import org.opensearch.knn.training.VectorReader;
import org.opensearch.plugins.ActionPlugin;
@@ -200,10 +201,14 @@ public Collection<Object> createComponents(
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

clusterService.addListener(TrainingJobClusterStateListener.getInstance());

knnStats = new KNNStats();
return ImmutableList.of(knnStats);
}
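
Reviewer note: the body of `TrainingJobClusterStateListener` is not visible in this view, so the following is only a guess at the kind of sweep the PR title describes: when cluster membership changes, models still marked as training whose assigned node is gone get moved to a failed state. Every name here (`StaleTrainingSweep`, `ModelInfo`, `failOrphanedTrainingJobs`, the error text) is hypothetical, and skipping models with an empty node assignment is an assumption, not something the diff states.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

final class StaleTrainingSweep {
    // Hypothetical, reduced state set for this sketch.
    enum State { TRAINING, CREATED, FAILED }

    // Hypothetical, mutable view of the metadata fields this sweep needs.
    static final class ModelInfo {
        State state;
        final String nodeAssignment;
        String error = "";

        ModelInfo(State state, String nodeAssignment) {
            this.state = state;
            this.nodeAssignment = nodeAssignment;
        }
    }

    // Marks TRAINING models as FAILED when their assigned node is not among the live nodes.
    // Returns the number of models that were updated.
    static int failOrphanedTrainingJobs(Map<String, ModelInfo> models, Set<String> liveNodeIds) {
        int updated = 0;
        for (ModelInfo model : models.values()) {
            // Assumption: models without a recorded assignment (pre-upgrade) are left untouched.
            boolean hasAssignment = !model.nodeAssignment.isEmpty();
            if (model.state == State.TRAINING && hasAssignment && !liveNodeIds.contains(model.nodeAssignment)) {
                model.state = State.FAILED;
                model.error = "Training node left the cluster before training completed";
                updated++;
            }
        }
        return updated;
    }

    public static void main(String[] args) {
        Map<String, ModelInfo> models = new HashMap<>();
        models.put("model-a", new ModelInfo(State.TRAINING, "node-1"));
        models.put("model-b", new ModelInfo(State.TRAINING, "node-2"));
        int updated = failOrphanedTrainingJobs(models, Collections.singleton("node-1"));
        System.out.println(updated + " model(s) marked as failed");
    }
}
```

Keying the check on the node's ephemeral ID, as `TrainingModelTransportAction` does when it creates the job, means a node that crashes and rejoins is still treated as gone, assuming the ephemeral ID is regenerated on restart.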
@@ -26,6 +26,7 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.concurrent.ExecutionException;

/**
* Transport action that trains a model and serializes it to model system index
@@ -66,7 +67,8 @@
trainingDataEntryContext,
modelAnonymousEntryContext,
request.getDimension(),
- request.getDescription()
+ request.getDescription(),
+ clusterService.localNode().getEphemeralId()
);

KNNCounter.TRAINING_REQUESTS.increment();
@@ -84,7 +86,7 @@
wrappedListener::onFailure
)
);
- } catch (IOException e) {
+ } catch (IOException | ExecutionException | InterruptedException e) {

wrappedListener.onFailure(e);
}
}
6 changes: 4 additions & 2 deletions src/main/java/org/opensearch/knn/training/TrainingJob.java
@@ -65,7 +65,8 @@ public TrainingJob(
NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext,
NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext,
int dimension,
- String description
+ String description,
+ String nodeAssignment
) {
// Generate random base64 string if one is not provided
this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID();
@@ -81,7 +82,8 @@
ModelState.TRAINING,
ZonedDateTime.now(ZoneOffset.UTC).toString(),
description,
- ""
+ "",
+ nodeAssignment
),
null,
this.modelId