Skip to content

Commit

Permalink
t1
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 15, 2023
1 parent 48ad895 commit 4413026
Showing 1 changed file with 75 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.input.execute.agent;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

//Find owner , Xun?
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.AGENT})
public class AgentMLInput extends MLInput {
public static final String AGENT_ID_FIELD = "agent_id";
public static final String PARAMETERS_FIELD = "parameters";

@Getter @Setter
private String agentId;

@Builder(builderMethodName = "AgentMLInputBuilder")
public AgentMLInput(String agentId, FunctionName functionName, MLInputDataset inputDataset) {
this.agentId = agentId;
this.algorithm = functionName;
this.inputDataset = inputDataset;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(agentId);
}

public AgentMLInput(StreamInput in) throws IOException {
super(in);
this.agentId = in.readString();
}

public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case AGENT_ID_FIELD:
agentId = parser.text();
break;
case PARAMETERS_FIELD:
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
inputDataset = new RemoteInferenceInputDataSet(parameters);
break;
default:
parser.skipChildren();
break;
}
}
}

}

0 comments on commit 4413026

Please sign in to comment.