Skip to content

Commit

Permalink
add UT for model type
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Aug 28, 2024
1 parent 795ff69 commit db00ec3
Showing 1 changed file with 25 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.junit.Assert;
import org.junit.Before;
Expand All @@ -27,13 +28,17 @@ public class GuardrailsTests {
String[] regex;
LocalRegexGuardrail inputLocalRegexGuardrail;
LocalRegexGuardrail outputLocalRegexGuardrail;
ModelGuardrail inputModelGuardrail;
ModelGuardrail outputModelGuardrail;

@Before
public void setUp() {
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
regex = List.of("regex1").toArray(new String[0]);
inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
inputModelGuardrail = new ModelGuardrail(Map.of("model_id", "guardrail_model_id", "response_validation_regex", "accept"));
outputModelGuardrail = new ModelGuardrail(Map.of("model_id", "guardrail_model_id", "response_validation_regex", "accept"));
}

@Test
Expand Down Expand Up @@ -103,4 +108,24 @@ public void parseNonType() throws IOException {
Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail);
Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail);
}

@Test
public void parseModelType() throws IOException {
String jsonStr = "{\"type\":\"model\","
+ "\"input_guardrail\":{\"model_id\":\"guardrail_model_id\",\"response_validation_regex\":\"accept\"},"
+ "\"output_guardrail\":{\"model_id\":\"guardrail_model_id\",\"response_validation_regex\":\"accept\"}}";
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
jsonStr
);
parser.nextToken();
Guardrails guardrails = Guardrails.parse(parser);

Assert.assertEquals(guardrails.getType(), "model");
Assert.assertEquals(guardrails.getInputGuardrail(), inputModelGuardrail);
Assert.assertEquals(guardrails.getOutputGuardrail(), outputModelGuardrail);
}
}

0 comments on commit db00ec3

Please sign in to comment.