From d11a15fdac337fe112209c4d5bf64328fb8d87cb Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Tue, 30 Jul 2024 16:08:31 +0200 Subject: [PATCH 01/10] draft fixing the xgboost import --- .../ranker/parser/XGBoostJsonParserV2.java | 263 ++++++++++++++++++ .../parser/XGBoostJsonParserV2Tests.java | 81 ++++++ 2 files changed, 344 insertions(+) create mode 100644 src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java create mode 100644 src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java new file mode 100644 index 00000000..ecadbd3a --- /dev/null +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java @@ -0,0 +1,263 @@ +package com.o19s.es.ltr.ranker.parser; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import com.o19s.es.ltr.ranker.normalizer.Normalizers; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.json.JsonXContent; + +import java.io.IOException; +import java.util.*; + +public class XGBoostJsonParserV2 implements LtrRankerParser { + + public static final String TYPE = "model/xgboost+json"; + + private static final Integer MISSING_NODE_ID = Integer.MAX_VALUE; + + @Override + public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { + XGBoostJsonParserV2.XGBoostDefinition modelDefinition; + try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, + model) + ) { + modelDefinition = new XGBoostJsonParserV2.XGBoostDefinition(set, parser.map()); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse XGBoost object", e); + } + + NaiveAdditiveDecisionTree.Node[] trees = modelDefinition.getTrees(set); + float[] weights = new float[trees.length]; + Arrays.fill(weights, 1F); + return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.normalizer); + } + + enum SplitType { + NUMERICAL(0), + CATEGORICAL(1); + + private final int value; + + SplitType(int value) { + this.value = value; + } + + public static SplitType fromValue(int value) { + for (SplitType type : values()) { + if (type.value == value) { + return type; + } + } + throw new IllegalArgumentException("Unknown SplitType value: " + value); + } + } + + class Node { + int nodeid; + int left; + int right; + int parent; + int splitIdx; + float splitCond; + boolean defaultLeft; + SplitType splitType; + List categories; + float baseWeight; + float lossChg; + float sumHess; + + Node(int nodeid, int left, int right, int parent, int splitIdx, float splitCond, boolean defaultLeft, + SplitType splitType, List categories, float baseWeight, float lossChg, float sumHess) { + this.nodeid = nodeid; + this.left = left; + this.right = right; + this.parent = parent; + this.splitIdx = splitIdx; + this.splitCond = splitCond; + this.defaultLeft = defaultLeft; + this.splitType = splitType; + this.categories = categories; + this.baseWeight = baseWeight; + this.lossChg = lossChg; + this.sumHess = sumHess; + } + } + + class Tree { + int treeId; + List nodes; + + Tree(int treeId, List nodes) { + this.treeId = treeId; + this.nodes = nodes; + } + + float lossChange(int nodeId) { + return nodes.get(nodeId).lossChg; + } + + float sumHessian(int nodeId) { + return nodes.get(nodeId).sumHess; + } + + float baseWeight(int nodeId) { + return nodes.get(nodeId).baseWeight; + } + + int splitIndex(int nodeId) { + return nodes.get(nodeId).splitIdx; + } + + float splitCondition(int nodeId) { + return nodes.get(nodeId).splitCond; + } + + List splitCategories(int nodeId) { + return nodes.get(nodeId).categories; + } + + boolean isCategorical(int nodeId) { + return nodes.get(nodeId).splitType == SplitType.CATEGORICAL; + } + + boolean isNumerical(int nodeId) { + return !isCategorical(nodeId); + } + + int parent(int nodeId) { + return nodes.get(nodeId).parent; + } + + int leftChild(int nodeId) { + return nodes.get(nodeId).left; + } + + int rightChild(int nodeId) { + return nodes.get(nodeId).right; + } + + boolean isLeaf(int nodeId) { + return nodes.get(nodeId).left == -1 && nodes.get(nodeId).right == -1; + } + + boolean isSplit(int nodeId) { + return !this.isLeaf(nodeId); + } + + boolean isDeleted(int nodeId) { + return splitIndex(nodeId) == MISSING_NODE_ID; + } + + NaiveAdditiveDecisionTree.Node toLibNode(int nodeid) { + if (isSplit(nodeid)) { + Node node = nodes.get(nodeid); + return new NaiveAdditiveDecisionTree.Split(toLibNode(node.left), toLibNode(node.right), + node.splitIdx, node.splitCond, node.left, MISSING_NODE_ID); + } else { + Node node = nodes.get(nodeid); + return new NaiveAdditiveDecisionTree.Leaf(node.baseWeight); + } + } + } + + class XGBoostDefinition { + int numOutputGroup; + int numFeature; + float baseScore; + List treeInfo; + List trees; + Normalizer normalizer = Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + + XGBoostDefinition(FeatureSet set, Map modelStr) { + Map learnerModelShape = (Map) ((Map) modelStr.get("learner")).get("learner_model_param"); + this.numOutputGroup = Integer.parseInt(learnerModelShape.get("num_class")); + this.numFeature = Integer.parseInt(learnerModelShape.get("num_feature")); + this.baseScore = Float.parseFloat(learnerModelShape.get("base_score")); + + Map gradientBooster = (Map) ((Map) modelStr.get("learner")).get("gradient_booster"); + this.treeInfo = (List) gradientBooster.get("tree_info"); + Map model = (Map) gradientBooster.get("model"); + Map modelShape = (Map) model.get("gbtree_model_param"); + + List> treesObj = (List>) model.get("trees"); + this.trees = new ArrayList<>(); + int numTrees = Integer.parseInt(modelShape.get("num_trees")); + + for (int i = 0; i < numTrees; i++) { + Map tree = treesObj.get(i); + int treeId = (int) tree.get("id"); + + List leftChildren = (List) tree.get("left_children"); + List rightChildren = (List) tree.get("right_children"); + List parents = (List) tree.get("parents"); + List splitConditions = ((List) tree.get("split_conditions")).stream().map(n -> n.floatValue()).toList(); + List splitIndices = (List) tree.get("split_indices"); + + List defaultLeft = toIntegers((List) tree.get("default_left")); + List splitTypes = toIntegers((List) tree.get("split_type")); + + List catSegments = (List) tree.get("categories_segments"); + List catSizes = (List) tree.get("categories_sizes"); + List catNodes = (List) tree.get("categories_nodes"); + List cats = (List) tree.get("categories"); + + int catCnt = 0; + int lastCatNode = !catNodes.isEmpty() ? catNodes.get(catCnt) : -1; + List> nodeCategories = new ArrayList<>(); + + for (int nodeId = 0; nodeId < leftChildren.size(); nodeId++) { + if (nodeId == lastCatNode) { + int beg = catSegments.get(catCnt); + int size = catSizes.get(catCnt); + int end = beg + size; + List nodeCats = cats.subList(beg, end); + catCnt++; + lastCatNode = catCnt < catNodes.size() ? catNodes.get(catCnt) : -1; + nodeCategories.add(nodeCats); + } else { + nodeCategories.add(new ArrayList<>()); + } + } + + List baseWeights = ((List) tree.get("base_weights")).stream().map(n -> n.floatValue()).toList(); + List lossChanges = ((List) tree.get("loss_changes")).stream().map(n -> n.floatValue()).toList(); + List sumHessian = ((List) tree.get("sum_hessian")).stream().map(n -> n.floatValue()).toList(); + + List nodes = new ArrayList<>(); + for (int nodeId = 0; nodeId < leftChildren.size(); nodeId++) { + nodes.add(new Node( + nodeId, + leftChildren.get(nodeId), + rightChildren.get(nodeId), + parents.get(nodeId), + splitIndices.get(nodeId), + splitConditions.get(nodeId), + defaultLeft.get(nodeId) == 1, + SplitType.fromValue(splitTypes.get(nodeId)), + nodeCategories.get(nodeId), + baseWeights.get(nodeId), + lossChanges.get(nodeId), + sumHessian.get(nodeId) + )); + } + + trees.add(new Tree(treeId, nodes)); + } + } + + private List toIntegers(List data) { + return new ArrayList<>(data); + } + + NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { + NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[this.trees.size()]; + ListIterator it = this.trees.listIterator(); + while (it.hasNext()) { + trees[it.nextIndex()] = it.next().toLibNode(0); + } + return trees; + } + } +} \ No newline at end of file diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java new file mode 100644 index 00000000..9d7be98a --- /dev/null +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java @@ -0,0 +1,81 @@ +package com.o19s.es.ltr.ranker.parser; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.SparseFeatureVector; +import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.common.io.Streams; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.sql.Array; + +import static com.o19s.es.ltr.LtrTestUtils.randomFeatureSet; + +public class XGBoostJsonParserV2Tests extends LuceneTestCase { + private final XGBoostJsonParserV2 parser = new XGBoostJsonParserV2(); + + public void testReadLeaf() throws IOException { + String model = + "{\"learner\":{" + + "\"attributes\":{}," + + "\"feature_names\":[]," + + "\"feature_types\":[]," + + "\"gradient_booster\":{" + + "\"model\":{" + + "\"gbtree_model_param\":{" + + "\"num_parallel_tree\":\"1\"," + + "\"num_trees\":\"1\"}," + + "\"iteration_indptr\":[0,1]," + + "\"tree_info\":[0]," + + "\"trees\":[{" + + "\"base_weights\":[-0E0]," + + "\"categories\":[]," + + "\"categories_nodes\":[]," + + "\"categories_segments\":[]," + + "\"categories_sizes\":[]," + + "\"default_left\":[0]," + + "\"id\":0," + + "\"left_children\":[-1]," + + "\"loss_changes\":[0E0]," + + "\"parents\":[2147483647]," + + "\"right_children\":[-1]," + + "\"split_conditions\":[-0E0]," + + "\"split_indices\":[0]," + + "\"split_type\":[0]," + + "\"sum_hessian\":[1E0]," + + "\"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"2\",\"num_nodes\":\"1\",\"size_leaf_vector\":\"1\"}}" + + "]}," + + "\"name\":\"gbtree\"" + + "}," + + "\"learner_model_param\":{" + + "\"base_score\":\"5E-1\"," + + "\"boost_from_average\":\"1\"," + + "\"num_class\":\"0\"," + + "\"num_feature\":\"2\"," + + "\"num_target\":\"1\"}," + + "\"objective\":{" + + "\"name\":\"binary:logistic\"," + + "\"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + "}" + + "}," + + "\"version\":[2,1,0]}"; + FeatureSet set = randomFeatureSet(); + NaiveAdditiveDecisionTree tree = parser.parse(set, model); + FeatureVector featureVector = new SparseFeatureVector(2); + featureVector.setFeatureScore(0, 2); + featureVector.setFeatureScore(1, 3); + assertEquals(0.0, tree.score(featureVector), Math.ulp(0.1F)); + } + + private String readModel(String model) throws IOException { + try (InputStream is = this.getClass().getResourceAsStream(model)) { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + Streams.copy(is.readAllBytes(), bos); + return bos.toString(StandardCharsets.UTF_8.name()); + } + } +} From 95376f5cc0c69081844f3cb7bc3a990d4a22db8f Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Tue, 30 Jul 2024 18:04:43 +0200 Subject: [PATCH 02/10] small changes before parking the pr --- .../ranker/parser/XGBoostJsonParserV2.java | 50 ++++++--- .../parser/XGBoostJsonParserV2Tests.java | 105 +++++++++--------- 2 files changed, 86 insertions(+), 69 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java index ecadbd3a..2dacb2b4 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java @@ -13,17 +13,17 @@ public class XGBoostJsonParserV2 implements LtrRankerParser { - public static final String TYPE = "model/xgboost+json"; + public static final String TYPE = "model/xgboost+json+v2"; private static final Integer MISSING_NODE_ID = Integer.MAX_VALUE; @Override public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { - XGBoostJsonParserV2.XGBoostDefinition modelDefinition; + XGBoostLearner modelDefinition; try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, model) ) { - modelDefinition = new XGBoostJsonParserV2.XGBoostDefinition(set, parser.map()); + modelDefinition = new XGBoostLearner(set, parser.map(), parser); } catch (IOException e) { throw new IllegalArgumentException("Unable to parse XGBoost object", e); } @@ -162,21 +162,26 @@ NaiveAdditiveDecisionTree.Node toLibNode(int nodeid) { } } - class XGBoostDefinition { + class XGBoostLearner { + int numOutputGroup; int numFeature; float baseScore; List treeInfo; List trees; - Normalizer normalizer = Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + Normalizer normalizer; - XGBoostDefinition(FeatureSet set, Map modelStr) { - Map learnerModelShape = (Map) ((Map) modelStr.get("learner")).get("learner_model_param"); + XGBoostLearner(FeatureSet set, Map modelObj, XContentParser parser) { + Map learner = (Map) modelObj.get("learner"); + Map learnerModelShape = (Map) learner.get("learner_model_param"); this.numOutputGroup = Integer.parseInt(learnerModelShape.get("num_class")); this.numFeature = Integer.parseInt(learnerModelShape.get("num_feature")); this.baseScore = Float.parseFloat(learnerModelShape.get("base_score")); - Map gradientBooster = (Map) ((Map) modelStr.get("learner")).get("gradient_booster"); + String normalizerName = (String) ((Map) learner.get("objective")).get("name"); + this.normalizer = getNormalizer(normalizerName); + + Map gradientBooster = (Map) learner.get("gradient_booster"); this.treeInfo = (List) gradientBooster.get("tree_info"); Map model = (Map) gradientBooster.get("model"); Map modelShape = (Map) model.get("gbtree_model_param"); @@ -192,11 +197,11 @@ class XGBoostDefinition { List leftChildren = (List) tree.get("left_children"); List rightChildren = (List) tree.get("right_children"); List parents = (List) tree.get("parents"); - List splitConditions = ((List) tree.get("split_conditions")).stream().map(n -> n.floatValue()).toList(); + List splitConditions = ((List) tree.get("split_conditions")).stream().map(Double::floatValue).toList(); List splitIndices = (List) tree.get("split_indices"); - List defaultLeft = toIntegers((List) tree.get("default_left")); - List splitTypes = toIntegers((List) tree.get("split_type")); + List defaultLeft = (List) tree.get("default_left"); + List splitTypes = (List) tree.get("split_type"); List catSegments = (List) tree.get("categories_segments"); List catSizes = (List) tree.get("categories_sizes"); @@ -221,9 +226,9 @@ class XGBoostDefinition { } } - List baseWeights = ((List) tree.get("base_weights")).stream().map(n -> n.floatValue()).toList(); - List lossChanges = ((List) tree.get("loss_changes")).stream().map(n -> n.floatValue()).toList(); - List sumHessian = ((List) tree.get("sum_hessian")).stream().map(n -> n.floatValue()).toList(); + List baseWeights = ((List) tree.get("base_weights")).stream().map(Double::floatValue).toList(); + List lossChanges = ((List) tree.get("loss_changes")).stream().map(Double::floatValue).toList(); + List sumHessian = ((List) tree.get("sum_hessian")).stream().map(Double::floatValue).toList(); List nodes = new ArrayList<>(); for (int nodeId = 0; nodeId < leftChildren.size(); nodeId++) { @@ -247,10 +252,6 @@ class XGBoostDefinition { } } - private List toIntegers(List data) { - return new ArrayList<>(data); - } - NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[this.trees.size()]; ListIterator it = this.trees.listIterator(); @@ -259,5 +260,18 @@ NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { } return trees; } + + Normalizer getNormalizer(String objectiveName) { + switch (objectiveName) { + case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> { + return Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + } + case "binary:logistic", "reg:logistic" -> { + return Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); + } + default -> + throw new IllegalArgumentException("Objective [" + objectiveName + "] is not a valid XGBoost objective"); + } + } } } \ No newline at end of file diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java index 9d7be98a..54c2b75b 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java @@ -1,6 +1,7 @@ package com.o19s.es.ltr.ranker.parser; import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; import com.o19s.es.ltr.ranker.SparseFeatureVector; import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; @@ -13,69 +14,71 @@ import java.nio.charset.StandardCharsets; import java.sql.Array; +import static com.o19s.es.ltr.LtrTestUtils.randomFeature; import static com.o19s.es.ltr.LtrTestUtils.randomFeatureSet; +import static java.util.Collections.singletonList; public class XGBoostJsonParserV2Tests extends LuceneTestCase { private final XGBoostJsonParserV2 parser = new XGBoostJsonParserV2(); public void testReadLeaf() throws IOException { String model = - "{\"learner\":{" + - "\"attributes\":{}," + - "\"feature_names\":[]," + - "\"feature_types\":[]," + - "\"gradient_booster\":{" + - "\"model\":{" + - "\"gbtree_model_param\":{" + - "\"num_parallel_tree\":\"1\"," + - "\"num_trees\":\"1\"}," + - "\"iteration_indptr\":[0,1]," + - "\"tree_info\":[0]," + - "\"trees\":[{" + - "\"base_weights\":[-0E0]," + - "\"categories\":[]," + - "\"categories_nodes\":[]," + - "\"categories_segments\":[]," + - "\"categories_sizes\":[]," + - "\"default_left\":[0]," + - "\"id\":0," + - "\"left_children\":[-1]," + - "\"loss_changes\":[0E0]," + - "\"parents\":[2147483647]," + - "\"right_children\":[-1]," + - "\"split_conditions\":[-0E0]," + - "\"split_indices\":[0]," + - "\"split_type\":[0]," + - "\"sum_hessian\":[1E0]," + - "\"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"2\",\"num_nodes\":\"1\",\"size_leaf_vector\":\"1\"}}" + - "]}," + - "\"name\":\"gbtree\"" + - "}," + - "\"learner_model_param\":{" + - "\"base_score\":\"5E-1\"," + - "\"boost_from_average\":\"1\"," + - "\"num_class\":\"0\"," + - "\"num_feature\":\"2\"," + - "\"num_target\":\"1\"}," + + "{\"learner\":" + + "{" + + "\"attributes\":{}," + + "\"feature_names\":[]," + + "\"feature_types\":[]," + + "\"gradient_booster\":{" + + "\"model\":{" + + "\"gbtree_model_param\":{" + + "\"num_parallel_tree\":\"1\"," + + "\"num_trees\":\"1\"}," + + "\"iteration_indptr\":[0,1]," + + "\"tree_info\":[0]," + + "\"trees\":[{" + + "\"base_weights\":[1E0, 10E0, 0E0]," + + "\"categories\":[]," + + "\"categories_nodes\":[]," + + "\"categories_segments\":[]," + + "\"categories_sizes\":[]," + + "\"default_left\":[0, 0, 0]," + + "\"id\":0," + + "\"left_children\":[2, -1, -1]," + + "\"loss_changes\":[0E0, 0E0, 0E0]," + + "\"parents\":[2147483647, 0, 0]," + + "\"right_children\":[1, -1, -1]," + + "\"split_conditions\":[3E0, -1E0, -1E0]," + + "\"split_indices\":[0, 0, 0]," + + "\"split_type\":[0, 0, 0]," + + "\"sum_hessian\":[1E0, 1E0, 1E0]," + + "\"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + "]}," + + "\"name\":\"gbtree\"" + + "}," + + "\"learner_model_param\":{" + + "\"base_score\":\"5E-1\"," + + "\"boost_from_average\":\"1\"," + + "\"num_class\":\"0\"," + + "\"num_feature\":\"2\"," + + "\"num_target\":\"1\"" + + "}," + "\"objective\":{" + - "\"name\":\"binary:logistic\"," + + "\"name\":\"reg:linear\"," + "\"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + - "}" + - "}," + - "\"version\":[2,1,0]}"; - FeatureSet set = randomFeatureSet(); + "}" + + "}," + + "\"version\":[2,1,0]" + + "}"; + + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); NaiveAdditiveDecisionTree tree = parser.parse(set, model); - FeatureVector featureVector = new SparseFeatureVector(2); + FeatureVector featureVector = new SparseFeatureVector(1); featureVector.setFeatureScore(0, 2); - featureVector.setFeatureScore(1, 3); assertEquals(0.0, tree.score(featureVector), Math.ulp(0.1F)); - } - private String readModel(String model) throws IOException { - try (InputStream is = this.getClass().getResourceAsStream(model)) { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - Streams.copy(is.readAllBytes(), bos); - return bos.toString(StandardCharsets.UTF_8.name()); - } + featureVector.setFeatureScore(0, 4); + assertEquals(10.0, tree.score(featureVector), Math.ulp(0.1F)); } + + // todo: more tests } From 30921864ab3017fd158fe91df9aad820143d913e Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Tue, 30 Jul 2024 20:42:14 +0200 Subject: [PATCH 03/10] reimplement using ObjectParser --- .../ranker/parser/XGBoostJsonParserV2.java | 472 +++++++++++------- 1 file changed, 281 insertions(+), 191 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java index 2dacb2b4..639c55e9 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java @@ -4,8 +4,8 @@ import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; import com.o19s.es.ltr.ranker.normalizer.Normalizer; import com.o19s.es.ltr.ranker.normalizer.Normalizers; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.xcontent.*; import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; @@ -19,259 +19,349 @@ public class XGBoostJsonParserV2 implements LtrRankerParser { @Override public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { - XGBoostLearner modelDefinition; + XGBoostJsonParserV2.XGBoostDefinition modelDefinition; try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, model) ) { - modelDefinition = new XGBoostLearner(set, parser.map(), parser); + modelDefinition = XGBoostJsonParserV2.XGBoostDefinition.parse(parser, set); } catch (IOException e) { - throw new IllegalArgumentException("Unable to parse XGBoost object", e); + throw new IllegalArgumentException("Cannot parse model", e); } - NaiveAdditiveDecisionTree.Node[] trees = modelDefinition.getTrees(set); + NaiveAdditiveDecisionTree.Node[] trees = modelDefinition.getLearner().getTrees(set); float[] weights = new float[trees.length]; Arrays.fill(weights, 1F); - return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.normalizer); + return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.getLearner().getObjective().getNormalizer()); } - enum SplitType { - NUMERICAL(0), - CATEGORICAL(1); + private static class XGBoostDefinition { + private static final ObjectParser PARSER; - private final int value; - - SplitType(int value) { - this.value = value; + static { + PARSER = new ObjectParser<>("xgboost_definition", true, XGBoostJsonParserV2.XGBoostDefinition::new); + PARSER.declareObject(XGBoostJsonParserV2.XGBoostDefinition::setLearner, XGBoostJsonParserV2.XGBoostLearner::parse, new ParseField("learner")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostDefinition::setVersion, new ParseField("version")); } - public static SplitType fromValue(int value) { - for (SplitType type : values()) { - if (type.value == value) { - return type; + public static XGBoostJsonParserV2.XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException { + XGBoostJsonParserV2.XGBoostDefinition definition; + XContentParser.Token startToken = parser.nextToken(); + + if (startToken == XContentParser.Token.START_OBJECT) { + try { + definition = PARSER.apply(parser, set); + } catch (XContentParseException e) { + throw new ParsingException(parser.getTokenLocation(), "Unable to parse XGBoost object", e); } + if (definition.learner == null) { + throw new ParsingException(parser.getTokenLocation(), "XGBoost model missing required field [learner]"); + } + } else { + throw new ParsingException(parser.getTokenLocation(), "Expected [START_ARRAY] or [START_OBJECT] but got [" + + startToken + "]"); } - throw new IllegalArgumentException("Unknown SplitType value: " + value); + return definition; } - } - class Node { - int nodeid; - int left; - int right; - int parent; - int splitIdx; - float splitCond; - boolean defaultLeft; - SplitType splitType; - List categories; - float baseWeight; - float lossChg; - float sumHess; - - Node(int nodeid, int left, int right, int parent, int splitIdx, float splitCond, boolean defaultLeft, - SplitType splitType, List categories, float baseWeight, float lossChg, float sumHess) { - this.nodeid = nodeid; - this.left = left; - this.right = right; - this.parent = parent; - this.splitIdx = splitIdx; - this.splitCond = splitCond; - this.defaultLeft = defaultLeft; - this.splitType = splitType; - this.categories = categories; - this.baseWeight = baseWeight; - this.lossChg = lossChg; - this.sumHess = sumHess; + private XGBoostLearner learner; + + public XGBoostLearner getLearner() { + return learner; + } + + public void setLearner(XGBoostLearner learner) { + this.learner = learner; + } + + private List version; + + public List getVersion() { + return version; + } + + public void setVersion(List version) { + this.version = version; } } - class Tree { - int treeId; - List nodes; + static class XGBoostObjective { + private String name; - Tree(int treeId, List nodes) { - this.treeId = treeId; - this.nodes = nodes; + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_objective", true, XGBoostJsonParserV2.XGBoostObjective::new); + PARSER.declareString(XGBoostJsonParserV2.XGBoostObjective::setName, new ParseField("name")); + } + + public static XGBoostJsonParserV2.XGBoostObjective parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); } - float lossChange(int nodeId) { - return nodes.get(nodeId).lossChg; + public XGBoostObjective() { } - float sumHessian(int nodeId) { - return nodes.get(nodeId).sumHess; + public XGBoostObjective(String name) { + this.name = name; } - float baseWeight(int nodeId) { - return nodes.get(nodeId).baseWeight; + public String getName() { + return name; } - int splitIndex(int nodeId) { - return nodes.get(nodeId).splitIdx; + public void setName(String name) { + this.name = name; } - float splitCondition(int nodeId) { - return nodes.get(nodeId).splitCond; + Normalizer getNormalizer() { + switch (this.name) { + case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> { + return Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + } + case "binary:logistic", "reg:logistic" -> { + return Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); + } + default -> + throw new IllegalArgumentException("Objective [" + name + "] is not a valid XGBoost objective"); + } } + } + + static class XGBoostGradientBooster { + private XGBoostModel model; - List splitCategories(int nodeId) { - return nodes.get(nodeId).categories; + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_gradient_booster", true, XGBoostJsonParserV2.XGBoostGradientBooster::new); + PARSER.declareObject(XGBoostJsonParserV2.XGBoostGradientBooster::setModel, XGBoostJsonParserV2.XGBoostModel::parse, new ParseField("model")); } - boolean isCategorical(int nodeId) { - return nodes.get(nodeId).splitType == SplitType.CATEGORICAL; + public static XGBoostJsonParserV2.XGBoostGradientBooster parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); } - boolean isNumerical(int nodeId) { - return !isCategorical(nodeId); + public XGBoostGradientBooster() { } - int parent(int nodeId) { - return nodes.get(nodeId).parent; + public XGBoostModel getModel() { + return model; } - int leftChild(int nodeId) { - return nodes.get(nodeId).left; + public void setModel(XGBoostModel model) { + this.model = model; } + } + + + static class XGBoostModel { + private List trees; - int rightChild(int nodeId) { - return nodes.get(nodeId).right; + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_model", true, XGBoostJsonParserV2.XGBoostModel::new); + PARSER.declareObjectArray(XGBoostJsonParserV2.XGBoostModel::setTrees, XGBoostJsonParserV2.XGBoostTree::parse, new ParseField("trees")); } - boolean isLeaf(int nodeId) { - return nodes.get(nodeId).left == -1 && nodes.get(nodeId).right == -1; + public static XGBoostJsonParserV2.XGBoostModel parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); } - boolean isSplit(int nodeId) { - return !this.isLeaf(nodeId); + public XGBoostModel() { } - boolean isDeleted(int nodeId) { - return splitIndex(nodeId) == MISSING_NODE_ID; + public List getTrees() { + return trees; } - NaiveAdditiveDecisionTree.Node toLibNode(int nodeid) { - if (isSplit(nodeid)) { - Node node = nodes.get(nodeid); - return new NaiveAdditiveDecisionTree.Split(toLibNode(node.left), toLibNode(node.right), - node.splitIdx, node.splitCond, node.left, MISSING_NODE_ID); - } else { - Node node = nodes.get(nodeid); - return new NaiveAdditiveDecisionTree.Leaf(node.baseWeight); - } + public void setTrees(List trees) { + this.trees = trees; } } - class XGBoostLearner { - - int numOutputGroup; - int numFeature; - float baseScore; - List treeInfo; - List trees; - Normalizer normalizer; - - XGBoostLearner(FeatureSet set, Map modelObj, XContentParser parser) { - Map learner = (Map) modelObj.get("learner"); - Map learnerModelShape = (Map) learner.get("learner_model_param"); - this.numOutputGroup = Integer.parseInt(learnerModelShape.get("num_class")); - this.numFeature = Integer.parseInt(learnerModelShape.get("num_feature")); - this.baseScore = Float.parseFloat(learnerModelShape.get("base_score")); - - String normalizerName = (String) ((Map) learner.get("objective")).get("name"); - this.normalizer = getNormalizer(normalizerName); - - Map gradientBooster = (Map) learner.get("gradient_booster"); - this.treeInfo = (List) gradientBooster.get("tree_info"); - Map model = (Map) gradientBooster.get("model"); - Map modelShape = (Map) model.get("gbtree_model_param"); - - List> treesObj = (List>) model.get("trees"); - this.trees = new ArrayList<>(); - int numTrees = Integer.parseInt(modelShape.get("num_trees")); - - for (int i = 0; i < numTrees; i++) { - Map tree = treesObj.get(i); - int treeId = (int) tree.get("id"); - - List leftChildren = (List) tree.get("left_children"); - List rightChildren = (List) tree.get("right_children"); - List parents = (List) tree.get("parents"); - List splitConditions = ((List) tree.get("split_conditions")).stream().map(Double::floatValue).toList(); - List splitIndices = (List) tree.get("split_indices"); - - List defaultLeft = (List) tree.get("default_left"); - List splitTypes = (List) tree.get("split_type"); - - List catSegments = (List) tree.get("categories_segments"); - List catSizes = (List) tree.get("categories_sizes"); - List catNodes = (List) tree.get("categories_nodes"); - List cats = (List) tree.get("categories"); - - int catCnt = 0; - int lastCatNode = !catNodes.isEmpty() ? catNodes.get(catCnt) : -1; - List> nodeCategories = new ArrayList<>(); - - for (int nodeId = 0; nodeId < leftChildren.size(); nodeId++) { - if (nodeId == lastCatNode) { - int beg = catSegments.get(catCnt); - int size = catSizes.get(catCnt); - int end = beg + size; - List nodeCats = cats.subList(beg, end); - catCnt++; - lastCatNode = catCnt < catNodes.size() ? catNodes.get(catCnt) : -1; - nodeCategories.add(nodeCats); - } else { - nodeCategories.add(new ArrayList<>()); - } - } + static class XGBoostLearner { - List baseWeights = ((List) tree.get("base_weights")).stream().map(Double::floatValue).toList(); - List lossChanges = ((List) tree.get("loss_changes")).stream().map(Double::floatValue).toList(); - List sumHessian = ((List) tree.get("sum_hessian")).stream().map(Double::floatValue).toList(); - - List nodes = new ArrayList<>(); - for (int nodeId = 0; nodeId < leftChildren.size(); nodeId++) { - nodes.add(new Node( - nodeId, - leftChildren.get(nodeId), - rightChildren.get(nodeId), - parents.get(nodeId), - splitIndices.get(nodeId), - splitConditions.get(nodeId), - defaultLeft.get(nodeId) == 1, - SplitType.fromValue(splitTypes.get(nodeId)), - nodeCategories.get(nodeId), - baseWeights.get(nodeId), - lossChanges.get(nodeId), - sumHessian.get(nodeId) - )); - } + // private int numOutputGroup; +// int numFeature; +// float baseScore; + private List treeInfo; + private XGBoostGradientBooster gradientBooster; + private XGBoostObjective objective; - trees.add(new Tree(treeId, nodes)); - } + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_learner", true, XGBoostJsonParserV2.XGBoostLearner::new); + PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setObjective, XGBoostJsonParserV2.XGBoostObjective::parse, new ParseField("objective")); + PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setGradientBooster, XGBoostJsonParserV2.XGBoostGradientBooster::parse, new ParseField("gradient_booster")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostLearner::setTreeInfo, new ParseField("tree_info")); + } + + public static XGBoostJsonParserV2.XGBoostLearner parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); + } + + XGBoostLearner() { } NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { - NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[this.trees.size()]; - ListIterator it = this.trees.listIterator(); + List parsedTrees = this.getGradientBooster().getModel().getTrees(); + NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[parsedTrees.size()]; + ListIterator it = parsedTrees.listIterator(); while (it.hasNext()) { - trees[it.nextIndex()] = it.next().toLibNode(0); + trees[it.nextIndex()] = it.next().asLibTree(); } return trees; } - Normalizer getNormalizer(String objectiveName) { - switch (objectiveName) { - case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> { - return Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); - } - case "binary:logistic", "reg:logistic" -> { - return Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); - } - default -> - throw new IllegalArgumentException("Objective [" + objectiveName + "] is not a valid XGBoost objective"); + + public XGBoostObjective getObjective() { + return objective; + } + + public void setObjective(XGBoostObjective objective) { + this.objective = objective; + } + + public List getTreeInfo() { + return treeInfo; + } + + public void setTreeInfo(List treeInfo) { + this.treeInfo = treeInfo; + } + + public XGBoostGradientBooster getGradientBooster() { + return gradientBooster; + } + + public void setGradientBooster(XGBoostGradientBooster gradientBooster) { + this.gradientBooster = gradientBooster; + } + } + + static class XGBoostTree { + private Integer treeId; + private List leftChildren; + private List rightChildren; + private List parents; + private List splitConditions; + private List splitIndices; + private List defaultLeft; + private List splitTypes; + private List baseWeights; + + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_tree", true, XGBoostJsonParserV2.XGBoostTree::new); + PARSER.declareInt(XGBoostJsonParserV2.XGBoostTree::setTreeId, new ParseField("id")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setLeftChildren, new ParseField("left_children")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setRightChildren, new ParseField("right_children")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setParents, new ParseField("parents")); + PARSER.declareFloatArray(XGBoostJsonParserV2.XGBoostTree::setSplitConditions, new ParseField("split_conditions")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setSplitIndices, new ParseField("split_indices")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setDefaultLeft, new ParseField("default_left")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setSplitTypes, new ParseField("split_type")); +// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCatSegments, new ParseField("categories_segments")); +// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCatSizes, new ParseField("categories_sizes")); +// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCatNodes, new ParseField("categories_nodes")); +// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCats, new ParseField("categories")); + PARSER.declareFloatArray(XGBoostJsonParserV2.XGBoostTree::setBaseWeights, new ParseField("base_weights")); + } + + public static XGBoostJsonParserV2.XGBoostTree parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); + } + + public Integer getTreeId() { + return treeId; + } + + public void setTreeId(Integer treeId) { + this.treeId = treeId; + } + + public List getLeftChildren() { + return leftChildren; + } + + public void setLeftChildren(List leftChildren) { + this.leftChildren = leftChildren; + } + + public List getRightChildren() { + return rightChildren; + } + + public void setRightChildren(List rightChildren) { + this.rightChildren = rightChildren; + } + + public List getParents() { + return parents; + } + + public void setParents(List parents) { + this.parents = parents; + } + + public List getSplitConditions() { + return splitConditions; + } + + public void setSplitConditions(List splitConditions) { + this.splitConditions = splitConditions; + } + + public List getSplitIndices() { + return splitIndices; + } + + public void setSplitIndices(List splitIndices) { + this.splitIndices = splitIndices; + } + + public List getDefaultLeft() { + return defaultLeft; + } + + public void setDefaultLeft(List defaultLeft) { + this.defaultLeft = defaultLeft; + } + + public List getSplitTypes() { + return splitTypes; + } + + public void setSplitTypes(List splitTypes) { + this.splitTypes = splitTypes; + } + + public NaiveAdditiveDecisionTree.Node asLibTree() { + return this.asLibTree(0); + } + + private boolean isSplit(Integer nodeId) { + return leftChildren.get(nodeId) != -1 && rightChildren.get(nodeId) != -1; + } + + private NaiveAdditiveDecisionTree.Node asLibTree(Integer nodeId) { + if (isSplit(nodeId)) { + return new NaiveAdditiveDecisionTree.Split(asLibTree(leftChildren.get(nodeId)), asLibTree(rightChildren.get(nodeId)), + splitIndices.get(nodeId), splitConditions.get(nodeId), splitIndices.get(nodeId), MISSING_NODE_ID); + } else { + return new NaiveAdditiveDecisionTree.Leaf(baseWeights.get(nodeId)); } } + + public List getBaseWeights() { + return baseWeights; + } + + public void setBaseWeights(List baseWeights) { + this.baseWeights = baseWeights; + } } -} \ No newline at end of file +} From e9d660f83006a489942322d4c3aba611c4f6c7c9 Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Tue, 30 Jul 2024 20:44:36 +0200 Subject: [PATCH 04/10] refactor a bit --- .../ranker/parser/XGBoostJsonParserV2.java | 134 ++++++++---------- 1 file changed, 63 insertions(+), 71 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java index 639c55e9..1b299abf 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java @@ -84,46 +84,61 @@ public void setVersion(List version) { } } - static class XGBoostObjective { - private String name; + static class XGBoostLearner { - private static final ObjectParser PARSER; + private List treeInfo; + private XGBoostGradientBooster gradientBooster; + private XGBoostObjective objective; + + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_objective", true, XGBoostJsonParserV2.XGBoostObjective::new); - PARSER.declareString(XGBoostJsonParserV2.XGBoostObjective::setName, new ParseField("name")); + PARSER = new ObjectParser<>("xgboost_learner", true, XGBoostJsonParserV2.XGBoostLearner::new); + PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setObjective, XGBoostJsonParserV2.XGBoostObjective::parse, new ParseField("objective")); + PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setGradientBooster, XGBoostJsonParserV2.XGBoostGradientBooster::parse, new ParseField("gradient_booster")); + PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostLearner::setTreeInfo, new ParseField("tree_info")); } - public static XGBoostJsonParserV2.XGBoostObjective parse(XContentParser parser, FeatureSet set) throws IOException { + public static XGBoostJsonParserV2.XGBoostLearner parse(XContentParser parser, FeatureSet set) throws IOException { return PARSER.apply(parser, set); } - public XGBoostObjective() { + XGBoostLearner() { } - public XGBoostObjective(String name) { - this.name = name; + NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { + List parsedTrees = this.getGradientBooster().getModel().getTrees(); + NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[parsedTrees.size()]; + ListIterator it = parsedTrees.listIterator(); + while (it.hasNext()) { + trees[it.nextIndex()] = it.next().asLibTree(); + } + return trees; } - public String getName() { - return name; + + public XGBoostObjective getObjective() { + return objective; } - public void setName(String name) { - this.name = name; + public void setObjective(XGBoostObjective objective) { + this.objective = objective; } - Normalizer getNormalizer() { - switch (this.name) { - case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> { - return Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); - } - case "binary:logistic", "reg:logistic" -> { - return Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); - } - default -> - throw new IllegalArgumentException("Objective [" + name + "] is not a valid XGBoost objective"); - } + public List getTreeInfo() { + return treeInfo; + } + + public void setTreeInfo(List treeInfo) { + this.treeInfo = treeInfo; + } + + public XGBoostGradientBooster getGradientBooster() { + return gradientBooster; + } + + public void setGradientBooster(XGBoostGradientBooster gradientBooster) { + this.gradientBooster = gradientBooster; } } @@ -153,7 +168,6 @@ public void setModel(XGBoostModel model) { } } - static class XGBoostModel { private List trees; @@ -180,64 +194,46 @@ public void setTrees(List trees) { } } - static class XGBoostLearner { - - // private int numOutputGroup; -// int numFeature; -// float baseScore; - private List treeInfo; - private XGBoostGradientBooster gradientBooster; - private XGBoostObjective objective; + static class XGBoostObjective { + private String name; - private static final ObjectParser PARSER; + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_learner", true, XGBoostJsonParserV2.XGBoostLearner::new); - PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setObjective, XGBoostJsonParserV2.XGBoostObjective::parse, new ParseField("objective")); - PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setGradientBooster, XGBoostJsonParserV2.XGBoostGradientBooster::parse, new ParseField("gradient_booster")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostLearner::setTreeInfo, new ParseField("tree_info")); + PARSER = new ObjectParser<>("xgboost_objective", true, XGBoostJsonParserV2.XGBoostObjective::new); + PARSER.declareString(XGBoostJsonParserV2.XGBoostObjective::setName, new ParseField("name")); } - public static XGBoostJsonParserV2.XGBoostLearner parse(XContentParser parser, FeatureSet set) throws IOException { + public static XGBoostJsonParserV2.XGBoostObjective parse(XContentParser parser, FeatureSet set) throws IOException { return PARSER.apply(parser, set); } - XGBoostLearner() { - } - - NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { - List parsedTrees = this.getGradientBooster().getModel().getTrees(); - NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[parsedTrees.size()]; - ListIterator it = parsedTrees.listIterator(); - while (it.hasNext()) { - trees[it.nextIndex()] = it.next().asLibTree(); - } - return trees; - } - - - public XGBoostObjective getObjective() { - return objective; - } - - public void setObjective(XGBoostObjective objective) { - this.objective = objective; + public XGBoostObjective() { } - public List getTreeInfo() { - return treeInfo; + public XGBoostObjective(String name) { + this.name = name; } - public void setTreeInfo(List treeInfo) { - this.treeInfo = treeInfo; + public String getName() { + return name; } - public XGBoostGradientBooster getGradientBooster() { - return gradientBooster; + public void setName(String name) { + this.name = name; } - public void setGradientBooster(XGBoostGradientBooster gradientBooster) { - this.gradientBooster = gradientBooster; + Normalizer getNormalizer() { + switch (this.name) { + case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> { + return Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + } + case "binary:logistic", "reg:logistic" -> { + return Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); + } + default -> + throw new IllegalArgumentException("Objective [" + name + "] is not a valid XGBoost objective"); + } } } @@ -264,10 +260,6 @@ static class XGBoostTree { PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setSplitIndices, new ParseField("split_indices")); PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setDefaultLeft, new ParseField("default_left")); PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setSplitTypes, new ParseField("split_type")); -// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCatSegments, new ParseField("categories_segments")); -// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCatSizes, new ParseField("categories_sizes")); -// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCatNodes, new ParseField("categories_nodes")); -// PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setCats, new ParseField("categories")); PARSER.declareFloatArray(XGBoostJsonParserV2.XGBoostTree::setBaseWeights, new ParseField("base_weights")); } From bcba0d72c6e8a263268fafbc1c664715c4bad83d Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Sat, 5 Oct 2024 17:58:26 +0200 Subject: [PATCH 05/10] add tests + minor fixes / refactoring --- .../com/o19s/es/ltr/LtrQueryParserPlugin.java | 2 + ...arserV2.java => XGBoostRawJsonParser.java} | 201 +++++---- .../parser/XGBoostJsonParserV2Tests.java | 84 ---- .../parser/XGBoostRawJsonParserTests.java | 423 ++++++++++++++++++ 4 files changed, 538 insertions(+), 172 deletions(-) rename src/main/java/com/o19s/es/ltr/ranker/parser/{XGBoostJsonParserV2.java => XGBoostRawJsonParser.java} (54%) delete mode 100644 src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java create mode 100644 src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java diff --git a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java index dc86730b..772b37d7 100644 --- a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java +++ b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java @@ -47,6 +47,7 @@ import com.o19s.es.ltr.ranker.parser.LinearRankerParser; import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; import com.o19s.es.ltr.ranker.parser.XGBoostJsonParser; +import com.o19s.es.ltr.ranker.parser.XGBoostRawJsonParser; import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine; import com.o19s.es.ltr.ranker.ranklib.RanklibModelParser; import com.o19s.es.ltr.rest.RestCreateModelFromSet; @@ -129,6 +130,7 @@ public LtrQueryParserPlugin(Settings settings) { .register(RanklibModelParser.TYPE, () -> new RanklibModelParser(ranklib.get())) .register(LinearRankerParser.TYPE, LinearRankerParser::new) .register(XGBoostJsonParser.TYPE, XGBoostJsonParser::new) + .register(XGBoostRawJsonParser.TYPE, XGBoostRawJsonParser::new) .build(); } diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java similarity index 54% rename from src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java rename to src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java index 1b299abf..3dcc68b5 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java @@ -11,19 +11,19 @@ import java.io.IOException; import java.util.*; -public class XGBoostJsonParserV2 implements LtrRankerParser { +public class XGBoostRawJsonParser implements LtrRankerParser { - public static final String TYPE = "model/xgboost+json+v2"; + public static final String TYPE = "model/xgboost+json+raw"; private static final Integer MISSING_NODE_ID = Integer.MAX_VALUE; @Override public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { - XGBoostJsonParserV2.XGBoostDefinition modelDefinition; + XGBoostRawJsonParser.XGBoostDefinition modelDefinition; try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, model) ) { - modelDefinition = XGBoostJsonParserV2.XGBoostDefinition.parse(parser, set); + modelDefinition = XGBoostRawJsonParser.XGBoostDefinition.parse(parser, set); } catch (IOException e) { throw new IllegalArgumentException("Cannot parse model", e); } @@ -35,16 +35,16 @@ public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { } private static class XGBoostDefinition { - private static final ObjectParser PARSER; + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_definition", true, XGBoostJsonParserV2.XGBoostDefinition::new); - PARSER.declareObject(XGBoostJsonParserV2.XGBoostDefinition::setLearner, XGBoostJsonParserV2.XGBoostLearner::parse, new ParseField("learner")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostDefinition::setVersion, new ParseField("version")); + PARSER = new ObjectParser<>("xgboost_definition", true, XGBoostRawJsonParser.XGBoostDefinition::new); + PARSER.declareObject(XGBoostRawJsonParser.XGBoostDefinition::setLearner, XGBoostRawJsonParser.XGBoostLearner::parse, new ParseField("learner")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostDefinition::setVersion, new ParseField("version")); } - public static XGBoostJsonParserV2.XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException { - XGBoostJsonParserV2.XGBoostDefinition definition; + public static XGBoostRawJsonParser.XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException { + XGBoostRawJsonParser.XGBoostDefinition definition; XContentParser.Token startToken = parser.nextToken(); if (startToken == XContentParser.Token.START_OBJECT) { @@ -56,9 +56,17 @@ public static XGBoostJsonParserV2.XGBoostDefinition parse(XContentParser parser, if (definition.learner == null) { throw new ParsingException(parser.getTokenLocation(), "XGBoost model missing required field [learner]"); } + List unknownFeatures = new ArrayList<>(); + for (String modelFeatureName : definition.learner.featureNames) { + if (!set.hasFeature(modelFeatureName)) { + unknownFeatures.add(modelFeatureName); + } + } + if (!unknownFeatures.isEmpty()) { + throw new ParsingException(parser.getTokenLocation(), "Unknown features in model: [" + String.join(", ", unknownFeatures) + "]"); + } } else { - throw new ParsingException(parser.getTokenLocation(), "Expected [START_ARRAY] or [START_OBJECT] but got [" - + startToken + "]"); + throw new ParsingException(parser.getTokenLocation(), "Expected [START_OBJECT] but got [" + startToken + "]"); } return definition; } @@ -86,20 +94,30 @@ public void setVersion(List version) { static class XGBoostLearner { - private List treeInfo; + private List featureNames; + private List featureTypes; private XGBoostGradientBooster gradientBooster; private XGBoostObjective objective; - private static final ObjectParser PARSER; + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_learner", true, XGBoostJsonParserV2.XGBoostLearner::new); - PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setObjective, XGBoostJsonParserV2.XGBoostObjective::parse, new ParseField("objective")); - PARSER.declareObject(XGBoostJsonParserV2.XGBoostLearner::setGradientBooster, XGBoostJsonParserV2.XGBoostGradientBooster::parse, new ParseField("gradient_booster")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostLearner::setTreeInfo, new ParseField("tree_info")); + PARSER = new ObjectParser<>("xgboost_learner", true, XGBoostRawJsonParser.XGBoostLearner::new); + PARSER.declareObject(XGBoostRawJsonParser.XGBoostLearner::setObjective, XGBoostRawJsonParser.XGBoostObjective::parse, new ParseField("objective")); + PARSER.declareObject(XGBoostRawJsonParser.XGBoostLearner::setGradientBooster, XGBoostRawJsonParser.XGBoostGradientBooster::parse, new ParseField("gradient_booster")); + PARSER.declareStringArray(XGBoostRawJsonParser.XGBoostLearner::setFeatureNames, new ParseField("feature_names")); + PARSER.declareStringArray(XGBoostRawJsonParser.XGBoostLearner::setFeatureTypes, new ParseField("feature_types")); + } + + private void setFeatureTypes(List featureTypes) { + this.featureTypes = featureTypes; + } + + private void setFeatureNames(List featureNames) { + this.featureNames = featureNames; } - public static XGBoostJsonParserV2.XGBoostLearner parse(XContentParser parser, FeatureSet set) throws IOException { + public static XGBoostRawJsonParser.XGBoostLearner parse(XContentParser parser, FeatureSet set) throws IOException { return PARSER.apply(parser, set); } @@ -107,13 +125,7 @@ public static XGBoostJsonParserV2.XGBoostLearner parse(XContentParser parser, Fe } NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { - List parsedTrees = this.getGradientBooster().getModel().getTrees(); - NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[parsedTrees.size()]; - ListIterator it = parsedTrees.listIterator(); - while (it.hasNext()) { - trees[it.nextIndex()] = it.next().asLibTree(); - } - return trees; + return this.getGradientBooster().getModel().getTrees(); } @@ -125,14 +137,6 @@ public void setObjective(XGBoostObjective objective) { this.objective = objective; } - public List getTreeInfo() { - return treeInfo; - } - - public void setTreeInfo(List treeInfo) { - this.treeInfo = treeInfo; - } - public XGBoostGradientBooster getGradientBooster() { return gradientBooster; } @@ -145,14 +149,14 @@ public void setGradientBooster(XGBoostGradientBooster gradientBooster) { static class XGBoostGradientBooster { private XGBoostModel model; - private static final ObjectParser PARSER; + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_gradient_booster", true, XGBoostJsonParserV2.XGBoostGradientBooster::new); - PARSER.declareObject(XGBoostJsonParserV2.XGBoostGradientBooster::setModel, XGBoostJsonParserV2.XGBoostModel::parse, new ParseField("model")); + PARSER = new ObjectParser<>("xgboost_gradient_booster", true, XGBoostRawJsonParser.XGBoostGradientBooster::new); + PARSER.declareObject(XGBoostRawJsonParser.XGBoostGradientBooster::setModel, XGBoostRawJsonParser.XGBoostModel::parse, new ParseField("model")); } - public static XGBoostJsonParserV2.XGBoostGradientBooster parse(XContentParser parser, FeatureSet set) throws IOException { + public static XGBoostRawJsonParser.XGBoostGradientBooster parse(XContentParser parser, FeatureSet set) throws IOException { return PARSER.apply(parser, set); } @@ -169,72 +173,82 @@ public void setModel(XGBoostModel model) { } static class XGBoostModel { - private List trees; + private NaiveAdditiveDecisionTree.Node[] trees; + private List treeInfo; - private static final ObjectParser PARSER; + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_model", true, XGBoostJsonParserV2.XGBoostModel::new); - PARSER.declareObjectArray(XGBoostJsonParserV2.XGBoostModel::setTrees, XGBoostJsonParserV2.XGBoostTree::parse, new ParseField("trees")); + PARSER = new ObjectParser<>("xgboost_model", true, XGBoostRawJsonParser.XGBoostModel::new); + PARSER.declareObjectArray(XGBoostRawJsonParser.XGBoostModel::setTrees, XGBoostRawJsonParser.XGBoostTree::parse, new ParseField("trees")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostModel::setTreeInfo, new ParseField("tree_info")); } - public static XGBoostJsonParserV2.XGBoostModel parse(XContentParser parser, FeatureSet set) throws IOException { - return PARSER.apply(parser, set); + public List getTreeInfo() { + return treeInfo; + } + + public void setTreeInfo(List treeInfo) { + this.treeInfo = treeInfo; + } + + public static XGBoostRawJsonParser.XGBoostModel parse(XContentParser parser, FeatureSet set) throws IOException { + try { + return PARSER.apply(parser, set); + } catch (IllegalArgumentException e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } } public XGBoostModel() { } - public List getTrees() { + public NaiveAdditiveDecisionTree.Node[] getTrees() { return trees; } - public void setTrees(List trees) { + public void setTrees(List parsedTrees) { + NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[parsedTrees.size()]; + ListIterator it = parsedTrees.listIterator(); + while (it.hasNext()) { + trees[it.nextIndex()] = it.next().getRootNode(); + } this.trees = trees; } } static class XGBoostObjective { - private String name; + private Normalizer normalizer; - private static final ObjectParser PARSER; + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_objective", true, XGBoostJsonParserV2.XGBoostObjective::new); - PARSER.declareString(XGBoostJsonParserV2.XGBoostObjective::setName, new ParseField("name")); + PARSER = new ObjectParser<>("xgboost_objective", true, XGBoostRawJsonParser.XGBoostObjective::new); + PARSER.declareString(XGBoostRawJsonParser.XGBoostObjective::setName, new ParseField("name")); } - public static XGBoostJsonParserV2.XGBoostObjective parse(XContentParser parser, FeatureSet set) throws IOException { + public static XGBoostRawJsonParser.XGBoostObjective parse(XContentParser parser, FeatureSet set) throws IOException { return PARSER.apply(parser, set); } public XGBoostObjective() { } - public XGBoostObjective(String name) { - this.name = name; - } - - public String getName() { - return name; - } public void setName(String name) { - this.name = name; - } - - Normalizer getNormalizer() { - switch (this.name) { - case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> { - return Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); - } - case "binary:logistic", "reg:logistic" -> { - return Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); - } + switch (name) { + case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> + this.normalizer = Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + case "binary:logistic", "reg:logistic" -> + this.normalizer = Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); default -> throw new IllegalArgumentException("Objective [" + name + "] is not a valid XGBoost objective"); } } + + Normalizer getNormalizer() { + return this.normalizer; + } } static class XGBoostTree { @@ -248,23 +262,27 @@ static class XGBoostTree { private List splitTypes; private List baseWeights; - private static final ObjectParser PARSER; + private NaiveAdditiveDecisionTree.Node rootNode; + + private static final ObjectParser PARSER; static { - PARSER = new ObjectParser<>("xgboost_tree", true, XGBoostJsonParserV2.XGBoostTree::new); - PARSER.declareInt(XGBoostJsonParserV2.XGBoostTree::setTreeId, new ParseField("id")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setLeftChildren, new ParseField("left_children")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setRightChildren, new ParseField("right_children")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setParents, new ParseField("parents")); - PARSER.declareFloatArray(XGBoostJsonParserV2.XGBoostTree::setSplitConditions, new ParseField("split_conditions")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setSplitIndices, new ParseField("split_indices")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setDefaultLeft, new ParseField("default_left")); - PARSER.declareIntArray(XGBoostJsonParserV2.XGBoostTree::setSplitTypes, new ParseField("split_type")); - PARSER.declareFloatArray(XGBoostJsonParserV2.XGBoostTree::setBaseWeights, new ParseField("base_weights")); - } - - public static XGBoostJsonParserV2.XGBoostTree parse(XContentParser parser, FeatureSet set) throws IOException { - return PARSER.apply(parser, set); + PARSER = new ObjectParser<>("xgboost_tree", true, XGBoostRawJsonParser.XGBoostTree::new); + PARSER.declareInt(XGBoostRawJsonParser.XGBoostTree::setTreeId, new ParseField("id")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setLeftChildren, new ParseField("left_children")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setRightChildren, new ParseField("right_children")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setParents, new ParseField("parents")); + PARSER.declareFloatArray(XGBoostRawJsonParser.XGBoostTree::setSplitConditions, new ParseField("split_conditions")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setSplitIndices, new ParseField("split_indices")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setDefaultLeft, new ParseField("default_left")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setSplitTypes, new ParseField("split_type")); + PARSER.declareFloatArray(XGBoostRawJsonParser.XGBoostTree::setBaseWeights, new ParseField("base_weights")); + } + + public static XGBoostRawJsonParser.XGBoostTree parse(XContentParser parser, FeatureSet set) throws IOException { + XGBoostRawJsonParser.XGBoostTree tree = PARSER.apply(parser, set); + tree.rootNode = tree.asLibTree(0); + return tree; } public Integer getTreeId() { @@ -331,15 +349,18 @@ public void setSplitTypes(List splitTypes) { this.splitTypes = splitTypes; } - public NaiveAdditiveDecisionTree.Node asLibTree() { - return this.asLibTree(0); - } - private boolean isSplit(Integer nodeId) { return leftChildren.get(nodeId) != -1 && rightChildren.get(nodeId) != -1; } private NaiveAdditiveDecisionTree.Node asLibTree(Integer nodeId) { + if (nodeId >= leftChildren.size()) { + throw new IllegalArgumentException("Node ID [" + nodeId + "] is invalid"); + } + if (nodeId >= rightChildren.size()) { + throw new IllegalArgumentException("Node ID [" + nodeId + "] is invalid"); + } + if (isSplit(nodeId)) { return new NaiveAdditiveDecisionTree.Split(asLibTree(leftChildren.get(nodeId)), asLibTree(rightChildren.get(nodeId)), splitIndices.get(nodeId), splitConditions.get(nodeId), splitIndices.get(nodeId), MISSING_NODE_ID); @@ -355,5 +376,9 @@ public List getBaseWeights() { public void setBaseWeights(List baseWeights) { this.baseWeights = baseWeights; } + + public NaiveAdditiveDecisionTree.Node getRootNode() { + return rootNode; + } } } diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java deleted file mode 100644 index 54c2b75b..00000000 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserV2Tests.java +++ /dev/null @@ -1,84 +0,0 @@ -package com.o19s.es.ltr.ranker.parser; - -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.ranker.SparseFeatureVector; -import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; -import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.elasticsearch.common.io.Streams; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.sql.Array; - -import static com.o19s.es.ltr.LtrTestUtils.randomFeature; -import static com.o19s.es.ltr.LtrTestUtils.randomFeatureSet; -import static java.util.Collections.singletonList; - -public class XGBoostJsonParserV2Tests extends LuceneTestCase { - private final XGBoostJsonParserV2 parser = new XGBoostJsonParserV2(); - - public void testReadLeaf() throws IOException { - String model = - "{\"learner\":" + - "{" + - "\"attributes\":{}," + - "\"feature_names\":[]," + - "\"feature_types\":[]," + - "\"gradient_booster\":{" + - "\"model\":{" + - "\"gbtree_model_param\":{" + - "\"num_parallel_tree\":\"1\"," + - "\"num_trees\":\"1\"}," + - "\"iteration_indptr\":[0,1]," + - "\"tree_info\":[0]," + - "\"trees\":[{" + - "\"base_weights\":[1E0, 10E0, 0E0]," + - "\"categories\":[]," + - "\"categories_nodes\":[]," + - "\"categories_segments\":[]," + - "\"categories_sizes\":[]," + - "\"default_left\":[0, 0, 0]," + - "\"id\":0," + - "\"left_children\":[2, -1, -1]," + - "\"loss_changes\":[0E0, 0E0, 0E0]," + - "\"parents\":[2147483647, 0, 0]," + - "\"right_children\":[1, -1, -1]," + - "\"split_conditions\":[3E0, -1E0, -1E0]," + - "\"split_indices\":[0, 0, 0]," + - "\"split_type\":[0, 0, 0]," + - "\"sum_hessian\":[1E0, 1E0, 1E0]," + - "\"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + - "]}," + - "\"name\":\"gbtree\"" + - "}," + - "\"learner_model_param\":{" + - "\"base_score\":\"5E-1\"," + - "\"boost_from_average\":\"1\"," + - "\"num_class\":\"0\"," + - "\"num_feature\":\"2\"," + - "\"num_target\":\"1\"" + - "}," + - "\"objective\":{" + - "\"name\":\"reg:linear\"," + - "\"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + - "}" + - "}," + - "\"version\":[2,1,0]" + - "}"; - - FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); - NaiveAdditiveDecisionTree tree = parser.parse(set, model); - FeatureVector featureVector = new SparseFeatureVector(1); - featureVector.setFeatureScore(0, 2); - assertEquals(0.0, tree.score(featureVector), Math.ulp(0.1F)); - - featureVector.setFeatureScore(0, 4); - assertEquals(10.0, tree.score(featureVector), Math.ulp(0.1F)); - } - - // todo: more tests -} diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java new file mode 100644 index 00000000..fe129646 --- /dev/null +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java @@ -0,0 +1,423 @@ +package com.o19s.es.ltr.ranker.parser; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; +import com.o19s.es.ltr.ranker.SparseFeatureVector; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.common.ParsingException; +import org.hamcrest.CoreMatchers; +import org.junit.Rule; +import org.junit.rules.ExpectedException; + +import java.io.IOException; + +import static com.o19s.es.ltr.LtrTestUtils.randomFeature; +import static java.util.Collections.singletonList; + +public class XGBoostRawJsonParserTests extends LuceneTestCase { + private final XGBoostRawJsonParser parser = new XGBoostRawJsonParser(); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + public void testSimpleSplit() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\"]," + + " \"feature_types\":[\"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[2, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"2\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:linear\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + NaiveAdditiveDecisionTree tree = parser.parse(set, model); + FeatureVector featureVector = new SparseFeatureVector(1); + featureVector.setFeatureScore(0, 2); + assertEquals(0.0, tree.score(featureVector), Math.ulp(0.1F)); + + featureVector.setFeatureScore(0, 4); + assertEquals(10.0, tree.score(featureVector), Math.ulp(0.1F)); + } + + public void testReadWithLogisticObjective() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\"]," + + " \"feature_types\":[\"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, -2E-1, 5E-1]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[2, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:logistic\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + NaiveAdditiveDecisionTree tree = parser.parse(set, model); + FeatureVector v = tree.newFeatureVector(null); + v.setFeatureScore(0, 2); + assertEquals(0.62245935F, tree.score(v), Math.ulp(0.62245935F)); + v.setFeatureScore(0, 4); + assertEquals(0.45016602F, tree.score(v), Math.ulp(0.45016602F)); + } + + public void testBadObjectiveParam() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\", \"feat2\"]," + + " \"feature_types\":[\"int\", \"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[2, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:invalid\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Unable to parse XGBoost object")); + } + + public void testSplitMissingLeftChild() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\"]," + + " \"feature_types\":[\"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[100, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:linear\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + + try { + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + parser.parse(set, model); + fail("Expected an exception"); + } catch (ParsingException e) { + assertThat(e.getMessage(), CoreMatchers.containsString("Unable to parse XGBoost object")); + Throwable rootCause = e.getCause().getCause().getCause().getCause().getCause().getCause(); + assertThat(rootCause, CoreMatchers.instanceOf(IllegalArgumentException.class)); + assertThat(rootCause.getMessage(), CoreMatchers.containsString("Node ID [100] is invalid")); + } + } + + public void testSplitMissingRightChild() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\"]," + + " \"feature_types\":[\"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[1, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[100, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:linear\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + + try { + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + parser.parse(set, model); + fail("Expected an exception"); + } catch (ParsingException e) { + assertThat(e.getMessage(), CoreMatchers.containsString("Unable to parse XGBoost object")); + Throwable rootCause = e.getCause().getCause().getCause().getCause().getCause().getCause(); + assertThat(rootCause, CoreMatchers.instanceOf(IllegalArgumentException.class)); + assertThat(rootCause.getMessage(), CoreMatchers.containsString("Node ID [100] is invalid")); + } + } + + public void testBadStruct() throws IOException { + String model = + "[{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\", \"feat2\"]," + + " \"feature_types\":[\"int\", \"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[2, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:linear\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}]"; + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Expected [START_OBJECT] but got")); + } + + public void testMissingFeat() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\", \"feat2\"]," + + " \"feature_types\":[\"int\",\"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[2, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 100]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"2\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"2\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:linear\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1234"))); + assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Unknown features in model: [feat1, feat2]")); + } +} \ No newline at end of file From 3e717e809f34454813d1f232195f9198947f5ef7 Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Sat, 5 Oct 2024 18:09:21 +0200 Subject: [PATCH 06/10] add tests + minor fixes / refactoring --- .../ranker/parser/XGBoostRawJsonParserTests.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java index fe129646..36fa0882 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java @@ -28,7 +28,7 @@ public void testSimpleSplit() throws IOException { " \"learner\":{" + " \"attributes\":{}," + " \"feature_names\":[\"feat1\"]," + - " \"feature_types\":[\"int\"]," + + " \"feature_types\":[\"float\"]," + " \"gradient_booster\":{" + " \"model\":{" + " \"gbtree_model_param\":{" + @@ -87,7 +87,7 @@ public void testReadWithLogisticObjective() throws IOException { " \"learner\":{" + " \"attributes\":{}," + " \"feature_names\":[\"feat1\"]," + - " \"feature_types\":[\"int\"]," + + " \"feature_types\":[\"float\"]," + " \"gradient_booster\":{" + " \"model\":{" + " \"gbtree_model_param\":{" + @@ -145,7 +145,7 @@ public void testBadObjectiveParam() throws IOException { " \"learner\":{" + " \"attributes\":{}," + " \"feature_names\":[\"feat1\", \"feat2\"]," + - " \"feature_types\":[\"int\", \"int\"]," + + " \"feature_types\":[\"float\", \"float\"]," + " \"gradient_booster\":{" + " \"model\":{" + " \"gbtree_model_param\":{" + @@ -199,7 +199,7 @@ public void testSplitMissingLeftChild() throws IOException { " \"learner\":{" + " \"attributes\":{}," + " \"feature_names\":[\"feat1\"]," + - " \"feature_types\":[\"int\"]," + + " \"feature_types\":[\"float\"]," + " \"gradient_booster\":{" + " \"model\":{" + " \"gbtree_model_param\":{" + @@ -260,7 +260,7 @@ public void testSplitMissingRightChild() throws IOException { " \"learner\":{" + " \"attributes\":{}," + " \"feature_names\":[\"feat1\"]," + - " \"feature_types\":[\"int\"]," + + " \"feature_types\":[\"float\"]," + " \"gradient_booster\":{" + " \"model\":{" + " \"gbtree_model_param\":{" + @@ -321,7 +321,7 @@ public void testBadStruct() throws IOException { " \"learner\":{" + " \"attributes\":{}," + " \"feature_names\":[\"feat1\", \"feat2\"]," + - " \"feature_types\":[\"int\", \"int\"]," + + " \"feature_types\":[\"float\", \"float\"]," + " \"gradient_booster\":{" + " \"model\":{" + " \"gbtree_model_param\":{" + @@ -374,7 +374,7 @@ public void testMissingFeat() throws IOException { " \"learner\":{" + " \"attributes\":{}," + " \"feature_names\":[\"feat1\", \"feat2\"]," + - " \"feature_types\":[\"int\",\"int\"]," + + " \"feature_types\":[\"float\",\"float\"]," + " \"gradient_booster\":{" + " \"model\":{" + " \"gbtree_model_param\":{" + From 3991c643287e5f03f7c9ed683cc3820b93ca92ec Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Sat, 5 Oct 2024 18:32:22 +0200 Subject: [PATCH 07/10] additional validation --- .../ranker/parser/XGBoostRawJsonParser.java | 18 ++- .../parser/XGBoostRawJsonParserTests.java | 114 +++++++++++++++++- 2 files changed, 128 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java index 3dcc68b5..5004c57a 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java @@ -65,6 +65,20 @@ public static XGBoostRawJsonParser.XGBoostDefinition parse(XContentParser parser if (!unknownFeatures.isEmpty()) { throw new ParsingException(parser.getTokenLocation(), "Unknown features in model: [" + String.join(", ", unknownFeatures) + "]"); } + if (definition.learner.featureNames.size() != definition.learner.featureTypes.size()) { + throw new ParsingException(parser.getTokenLocation(), + "Feature names list and feature types list must have the same length"); + } + Optional firstUnsupportedType = definition.learner.featureTypes.stream() + .filter(typeStr -> !typeStr.equals("float")) + .findFirst(); + if (firstUnsupportedType.isPresent()) { + throw new ParsingException(parser.getTokenLocation(), + "The LTR plugin only supports float feature types " + + "because Elasticsearch scores are always float32. " + + "Found feature type [" + firstUnsupportedType.get() + "] in model" + ); + } } else { throw new ParsingException(parser.getTokenLocation(), "Expected [START_OBJECT] but got [" + startToken + "]"); } @@ -355,10 +369,10 @@ private boolean isSplit(Integer nodeId) { private NaiveAdditiveDecisionTree.Node asLibTree(Integer nodeId) { if (nodeId >= leftChildren.size()) { - throw new IllegalArgumentException("Node ID [" + nodeId + "] is invalid"); + throw new IllegalArgumentException("Child node reference ID [" + nodeId + "] is invalid"); } if (nodeId >= rightChildren.size()) { - throw new IllegalArgumentException("Node ID [" + nodeId + "] is invalid"); + throw new IllegalArgumentException("Child node reference ID [" + nodeId + "] is invalid"); } if (isSplit(nodeId)) { diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java index 36fa0882..fd7e5477 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java @@ -12,6 +12,7 @@ import org.junit.rules.ExpectedException; import java.io.IOException; +import java.util.List; import static com.o19s.es.ltr.LtrTestUtils.randomFeature; import static java.util.Collections.singletonList; @@ -193,6 +194,115 @@ public void testBadObjectiveParam() throws IOException { CoreMatchers.containsString("Unable to parse XGBoost object")); } + public void testBadFeatureTypeParam() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\"]," + + " \"feature_types\":[\"int\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[2, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:linear\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); + assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("The LTR plugin only supports float feature types because " + + "Elasticsearch scores are always float32. Found feature type [int] in model")); + } + + public void testMismatchingFeatureList() throws IOException { + String model = + "{" + + " \"learner\":{" + + " \"attributes\":{}," + + " \"feature_names\":[\"feat1\", \"feat2\"]," + + " \"feature_types\":[\"float\"]," + + " \"gradient_booster\":{" + + " \"model\":{" + + " \"gbtree_model_param\":{" + + " \"num_parallel_tree\":\"1\"," + + " \"num_trees\":\"1\"}," + + " \"iteration_indptr\":[0,1]," + + " \"tree_info\":[0]," + + " \"trees\":[{" + + " \"base_weights\":[1E0, 10E0, 0E0]," + + " \"categories\":[]," + + " \"categories_nodes\":[]," + + " \"categories_segments\":[]," + + " \"categories_sizes\":[]," + + " \"default_left\":[0, 0, 0]," + + " \"id\":0," + + " \"left_children\":[2, -1, -1]," + + " \"loss_changes\":[0E0, 0E0, 0E0]," + + " \"parents\":[2147483647, 0, 0]," + + " \"right_children\":[1, -1, -1]," + + " \"split_conditions\":[3E0, -1E0, -1E0]," + + " \"split_indices\":[0, 0, 0]," + + " \"split_type\":[0, 0, 0]," + + " \"sum_hessian\":[1E0, 1E0, 1E0]," + + " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " ]}," + + " \"name\":\"gbtree\"" + + " }," + + " \"learner_model_param\":{" + + " \"base_score\":\"5E-1\"," + + " \"boost_from_average\":\"1\"," + + " \"num_class\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_target\":\"1\"" + + " }," + + " \"objective\":{" + + " \"name\":\"reg:logistic\"," + + " \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" + + " }" + + " }," + + " \"version\":[2,1,0]" + + "}"; + + FeatureSet set = new StoredFeatureSet("set", List.of(randomFeature("feat1"), randomFeature("feat2"))); + assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Feature names list and feature types list must have the same length")); + } + public void testSplitMissingLeftChild() throws IOException { String model = "{" + @@ -250,7 +360,7 @@ public void testSplitMissingLeftChild() throws IOException { assertThat(e.getMessage(), CoreMatchers.containsString("Unable to parse XGBoost object")); Throwable rootCause = e.getCause().getCause().getCause().getCause().getCause().getCause(); assertThat(rootCause, CoreMatchers.instanceOf(IllegalArgumentException.class)); - assertThat(rootCause.getMessage(), CoreMatchers.containsString("Node ID [100] is invalid")); + assertThat(rootCause.getMessage(), CoreMatchers.containsString("Child node reference ID [100] is invalid")); } } @@ -311,7 +421,7 @@ public void testSplitMissingRightChild() throws IOException { assertThat(e.getMessage(), CoreMatchers.containsString("Unable to parse XGBoost object")); Throwable rootCause = e.getCause().getCause().getCause().getCause().getCause().getCause(); assertThat(rootCause, CoreMatchers.instanceOf(IllegalArgumentException.class)); - assertThat(rootCause.getMessage(), CoreMatchers.containsString("Node ID [100] is invalid")); + assertThat(rootCause.getMessage(), CoreMatchers.containsString("Child node reference ID [100] is invalid")); } } From ca118b49b8f73c100e8767e0b23429f612a71727 Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Mon, 7 Oct 2024 11:01:40 +0200 Subject: [PATCH 08/10] newline --- .../o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java index fd7e5477..fb261103 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java @@ -530,4 +530,4 @@ public void testMissingFeat() throws IOException { assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), CoreMatchers.containsString("Unknown features in model: [feat1, feat2]")); } -} \ No newline at end of file +} From 8075a3f67618a36b4196a0d30fe0d8a62b813948 Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Wed, 9 Oct 2024 14:33:25 +0200 Subject: [PATCH 09/10] fix code style --- .../ranker/parser/XGBoostRawJsonParser.java | 56 +++++++++++---- .../parser/XGBoostRawJsonParserTests.java | 71 +++++++++++++++---- 2 files changed, 100 insertions(+), 27 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java index 5004c57a..5dcfe271 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java @@ -5,11 +5,20 @@ import com.o19s.es.ltr.ranker.normalizer.Normalizer; import com.o19s.es.ltr.ranker.normalizer.Normalizers; import org.elasticsearch.common.ParsingException; -import org.elasticsearch.xcontent.*; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; -import java.util.*; +//import java.util.*; +import java.util.Arrays; +import java.util.List; +import java.util.ArrayList; +import java.util.Optional; +import java.util.ListIterator; public class XGBoostRawJsonParser implements LtrRankerParser { @@ -39,7 +48,11 @@ private static class XGBoostDefinition { static { PARSER = new ObjectParser<>("xgboost_definition", true, XGBoostRawJsonParser.XGBoostDefinition::new); - PARSER.declareObject(XGBoostRawJsonParser.XGBoostDefinition::setLearner, XGBoostRawJsonParser.XGBoostLearner::parse, new ParseField("learner")); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostDefinition::setLearner, + XGBoostRawJsonParser.XGBoostLearner::parse, + new ParseField("learner") + ); PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostDefinition::setVersion, new ParseField("version")); } @@ -63,7 +76,8 @@ public static XGBoostRawJsonParser.XGBoostDefinition parse(XContentParser parser } } if (!unknownFeatures.isEmpty()) { - throw new ParsingException(parser.getTokenLocation(), "Unknown features in model: [" + String.join(", ", unknownFeatures) + "]"); + throw new ParsingException(parser.getTokenLocation(), "Unknown features in model: [" + + String.join(", ", unknownFeatures) + "]"); } if (definition.learner.featureNames.size() != definition.learner.featureTypes.size()) { throw new ParsingException(parser.getTokenLocation(), @@ -117,8 +131,16 @@ static class XGBoostLearner { static { PARSER = new ObjectParser<>("xgboost_learner", true, XGBoostRawJsonParser.XGBoostLearner::new); - PARSER.declareObject(XGBoostRawJsonParser.XGBoostLearner::setObjective, XGBoostRawJsonParser.XGBoostObjective::parse, new ParseField("objective")); - PARSER.declareObject(XGBoostRawJsonParser.XGBoostLearner::setGradientBooster, XGBoostRawJsonParser.XGBoostGradientBooster::parse, new ParseField("gradient_booster")); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostLearner::setObjective, + XGBoostRawJsonParser.XGBoostObjective::parse, + new ParseField("objective") + ); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostLearner::setGradientBooster, + XGBoostRawJsonParser.XGBoostGradientBooster::parse, + new ParseField("gradient_booster") + ); PARSER.declareStringArray(XGBoostRawJsonParser.XGBoostLearner::setFeatureNames, new ParseField("feature_names")); PARSER.declareStringArray(XGBoostRawJsonParser.XGBoostLearner::setFeatureTypes, new ParseField("feature_types")); } @@ -142,7 +164,6 @@ NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { return this.getGradientBooster().getModel().getTrees(); } - public XGBoostObjective getObjective() { return objective; } @@ -167,14 +188,18 @@ static class XGBoostGradientBooster { static { PARSER = new ObjectParser<>("xgboost_gradient_booster", true, XGBoostRawJsonParser.XGBoostGradientBooster::new); - PARSER.declareObject(XGBoostRawJsonParser.XGBoostGradientBooster::setModel, XGBoostRawJsonParser.XGBoostModel::parse, new ParseField("model")); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostGradientBooster::setModel, + XGBoostRawJsonParser.XGBoostModel::parse, + new ParseField("model") + ); } - public static XGBoostRawJsonParser.XGBoostGradientBooster parse(XContentParser parser, FeatureSet set) throws IOException { + static XGBoostRawJsonParser.XGBoostGradientBooster parse(XContentParser parser, FeatureSet set) throws IOException { return PARSER.apply(parser, set); } - public XGBoostGradientBooster() { + XGBoostGradientBooster() { } public XGBoostModel getModel() { @@ -194,7 +219,11 @@ static class XGBoostModel { static { PARSER = new ObjectParser<>("xgboost_model", true, XGBoostRawJsonParser.XGBoostModel::new); - PARSER.declareObjectArray(XGBoostRawJsonParser.XGBoostModel::setTrees, XGBoostRawJsonParser.XGBoostTree::parse, new ParseField("trees")); + PARSER.declareObjectArray( + XGBoostRawJsonParser.XGBoostModel::setTrees, + XGBoostRawJsonParser.XGBoostTree::parse, + new ParseField("trees") + ); PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostModel::setTreeInfo, new ParseField("tree_info")); } @@ -214,7 +243,7 @@ public static XGBoostRawJsonParser.XGBoostModel parse(XContentParser parser, Fea } } - public XGBoostModel() { + XGBoostModel() { } public NaiveAdditiveDecisionTree.Node[] getTrees() { @@ -245,10 +274,9 @@ public static XGBoostRawJsonParser.XGBoostObjective parse(XContentParser parser, return PARSER.apply(parser, set); } - public XGBoostObjective() { + XGBoostObjective() { } - public void setName(String name) { switch (name) { case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java index fb261103..857bee46 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java @@ -53,7 +53,12 @@ public void testSimpleSplit() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -61,7 +66,7 @@ public void testSimpleSplit() throws IOException { " \"base_score\":\"5E-1\"," + " \"boost_from_average\":\"1\"," + " \"num_class\":\"0\"," + - " \"num_feature\":\"2\"," + + " \"num_feature\":\"1\"," + " \"num_target\":\"1\"" + " }," + " \"objective\":{" + @@ -112,7 +117,12 @@ public void testReadWithLogisticObjective() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -170,7 +180,12 @@ public void testBadObjectiveParam() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"2\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -178,7 +193,7 @@ public void testBadObjectiveParam() throws IOException { " \"base_score\":\"5E-1\"," + " \"boost_from_average\":\"1\"," + " \"num_class\":\"0\"," + - " \"num_feature\":\"1\"," + + " \"num_feature\":\"2\"," + " \"num_target\":\"1\"" + " }," + " \"objective\":{" + @@ -224,7 +239,12 @@ public void testBadFeatureTypeParam() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -279,7 +299,12 @@ public void testMismatchingFeatureList() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"2\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -287,7 +312,7 @@ public void testMismatchingFeatureList() throws IOException { " \"base_score\":\"5E-1\"," + " \"boost_from_average\":\"1\"," + " \"num_class\":\"0\"," + - " \"num_feature\":\"1\"," + + " \"num_feature\":\"2\"," + " \"num_target\":\"1\"" + " }," + " \"objective\":{" + @@ -333,7 +358,12 @@ public void testSplitMissingLeftChild() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -394,7 +424,12 @@ public void testSplitMissingRightChild() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"1\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -455,7 +490,12 @@ public void testBadStruct() throws IOException { " \"split_indices\":[0, 0, 0]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"1\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"2\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + @@ -463,7 +503,7 @@ public void testBadStruct() throws IOException { " \"base_score\":\"5E-1\"," + " \"boost_from_average\":\"1\"," + " \"num_class\":\"0\"," + - " \"num_feature\":\"1\"," + + " \"num_feature\":\"2\"," + " \"num_target\":\"1\"" + " }," + " \"objective\":{" + @@ -508,7 +548,12 @@ public void testMissingFeat() throws IOException { " \"split_indices\":[0, 0, 100]," + " \"split_type\":[0, 0, 0]," + " \"sum_hessian\":[1E0, 1E0, 1E0]," + - " \"tree_param\":{\"num_deleted\":\"0\",\"num_feature\":\"2\",\"num_nodes\":\"3\",\"size_leaf_vector\":\"1\"}}" + + " \"tree_param\":{" + + " \"num_deleted\":\"0\"," + + " \"num_feature\":\"2\"," + + " \"num_nodes\":\"3\"," + + " \"size_leaf_vector\":\"1\"}" + + " }" + " ]}," + " \"name\":\"gbtree\"" + " }," + From 2f88306b2a12d27b526c9ca69df057c31cd2cd2f Mon Sep 17 00:00:00 2001 From: Platon Bibik Date: Wed, 9 Oct 2024 14:35:24 +0200 Subject: [PATCH 10/10] fix code style --- .../o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java index 5dcfe271..658e7e76 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java @@ -5,20 +5,19 @@ import com.o19s.es.ltr.ranker.normalizer.Normalizer; import com.o19s.es.ltr.ranker.normalizer.Normalizers; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xcontent.ObjectParser; -import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; -//import java.util.*; import java.util.Arrays; -import java.util.List; import java.util.ArrayList; -import java.util.Optional; +import java.util.List; import java.util.ListIterator; +import java.util.Optional; public class XGBoostRawJsonParser implements LtrRankerParser {