diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index e5806103a9..7f5372df68 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -18,15 +18,13 @@ import java.util.Optional; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.transport.model.MLModelGetAction; -import org.opensearch.ml.common.transport.model.MLModelGetRequest; -import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.model.MLModelManager; @@ -91,16 +89,14 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client } return channel -> { - MLModelGetRequest getModelRequest = new MLModelGetRequest(modelId, false); - ActionListener listener = ActionListener.wrap(r -> { - MLModel mlModel = r.getMlModel(); + ActionListener listener = ActionListener.wrap(mlModel -> { String algoName = mlModel.getAlgorithm().name(); client - .execute( - MLPredictionTaskAction.INSTANCE, - getRequest(modelId, algoName, request), - new RestToXContentListener<>(channel) - ); + .execute( + MLPredictionTaskAction.INSTANCE, + getRequest(modelId, algoName, request), + new RestToXContentListener<>(channel) + ); }, e -> { log.error("Failed to get ML model", e); try { @@ -109,8 +105,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client log.error("Failed to send error response", ex); } }); - client.execute(MLModelGetAction.INSTANCE, getModelRequest, listener); - + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + modelManager.getModel(modelId, ActionListener.runBefore(listener, () -> context.restore())); + } }; }