From 26d79c7b4d2e5e7f40d4a8635017759c2fc4bf70 Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Sun, 10 Dec 2023 20:33:48 -0600 Subject: [PATCH] Update policy gradient learning and learning logs. --- .../weight/gradient/GradientDescent.java | 45 +++-- .../weight/gradient/minimizer/Minimizer.java | 4 - .../onlineminimizer/OnlineMinimizer.java | 4 - .../policygradient/PolicyGradient.java | 154 +++++++++++------- .../PolicyGradientBinaryCrossEntropy.java | 12 +- .../PolicyGradientSquaredError.java | 6 + .../java/org/linqs/psl/config/Options.java | 12 +- .../psl/model/deep/DeepModelPredicate.java | 4 + .../linqs/psl/model/predicate/Predicate.java | 4 +- .../java/org/linqs/psl/reasoner/Reasoner.java | 2 +- 10 files changed, 154 insertions(+), 93 deletions(-) diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index ff9fba5e2..d3b9be16b 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -54,13 +54,15 @@ public abstract class GradientDescent extends WeightLearningApplication { * on the chosen loss with unit simplex constrained weights. * NONE: Perform standard gradient descent with only lower bound (>=0) constraints on the weights. */ - public static enum GDExtension { + public static enum SymbolicWeightUpdate { MIRROR_DESCENT, PROJECTED_GRADIENT, - NONE + GRADIENT_DESCENT } - protected GDExtension gdExtension; + protected boolean symbolicWeightLearning; + protected SymbolicWeightUpdate symbolicWeightUpdate; + protected Map ruleIndexMap; protected float[] weightGradient; @@ -124,7 +126,8 @@ public GradientDescent(List rules, Database trainTargetDatabase, Database Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) { super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation); - gdExtension = GDExtension.valueOf(Options.WLA_GRADIENT_DESCENT_EXTENSION.getString().toUpperCase()); + symbolicWeightLearning = Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.getBoolean(); + symbolicWeightUpdate = SymbolicWeightUpdate.valueOf(Options.WLA_GRADIENT_DESCENT_EXTENSION.getString().toUpperCase()); ruleIndexMap = new HashMap(mutableRules.size()); for (int i = 0; i < mutableRules.size(); i++) { @@ -286,7 +289,7 @@ protected void initializeGradients() { } protected void initForLearning() { - switch (gdExtension) { + switch (symbolicWeightUpdate) { case MIRROR_DESCENT: case PROJECTED_GRADIENT: // Initialize weights to be centered on the unit simplex. @@ -348,10 +351,12 @@ protected void doLearn() { DeepPredicate.trainModeAllDeepPredicates(); + int numBatches = 0; + float averageBatchObjective = 0.0f; batchGenerator.permuteBatchOrdering(); int batchId = batchGenerator.epochStart(); while (!batchGenerator.isEpochComplete()) { - long batchStart = System.currentTimeMillis(); + numBatches++; setBatch(batchId); DeepPredicate.predictAllDeepPredicates(); @@ -364,7 +369,7 @@ protected void doLearn() { clipWeightGradient(); } - float batchObjective = computeTotalLoss(); + averageBatchObjective += computeTotalLoss(); gradientStep(epoch); @@ -372,15 +377,15 @@ protected void doLearn() { epochDeepAtomValueMovement += DeepPredicate.predictAllDeepPredicates(); } - long batchEnd = System.currentTimeMillis(); - - log.trace("Batch: {} -- Weight Learning Objective: {}, Gradient Magnitude: {}, Iteration Time: {}", - batchId, batchObjective, computeGradientNorm(), (batchEnd - batchStart)); - batchId = batchGenerator.nextBatch(); } batchGenerator.epochEnd(); + if (numBatches > 0) { + // Average the objective across batches. + averageBatchObjective /= numBatches; + } + setFullModel(); long end = System.currentTimeMillis(); @@ -396,7 +401,7 @@ protected void doLearn() { setFullModel(); epoch++; - log.trace("Epoch: {} -- Iteration Time: {}", epoch, (end - start)); + log.trace("Epoch: {}, Weight Learning Objective: {}, Iteration Time: {}", epoch, averageBatchObjective, (end - start)); } log.info("Gradient Descent Weight Learning Finished."); @@ -655,9 +660,13 @@ protected void internalParameterGradientStep(int epoch) { * Return the total change in the weights. */ protected void weightGradientStep(int epoch) { + if (!symbolicWeightLearning) { + return; + } + float stepSize = computeStepSize(epoch); - switch (gdExtension) { + switch (symbolicWeightUpdate) { case MIRROR_DESCENT: float exponentiatedGradientSum = 0.0f; for (int j = 0; j < mutableRules.size(); j++) { @@ -713,7 +722,7 @@ protected float computeStepSize(int epoch) { protected float computeGradientNorm() { float norm = 0.0f; - switch (gdExtension) { + switch (symbolicWeightUpdate) { case MIRROR_DESCENT: norm = computeMirrorDescentNorm(); break; @@ -931,8 +940,6 @@ protected float computeTotalLoss() { float learningLoss = computeLearningLoss(); float regularization = computeRegularization(); - log.trace("Learning Loss: {}, Regularization: {}", learningLoss, regularization); - return learningLoss + regularization; } @@ -963,6 +970,10 @@ protected float computeRegularization() { protected void computeTotalWeightGradient() { Arrays.fill(weightGradient, 0.0f); + if (!symbolicWeightLearning) { + return; + } + addLearningLossWeightGradient(); addRegularizationWeightGradient(); } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java index 65c4f3a61..2b3819b41 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java @@ -611,10 +611,6 @@ protected float computeLearningLoss() { float objectiveDifference = augmentedInferenceEnergy - mapEnergy; float constraintViolation = Math.max(0.0f, objectiveDifference - constraintRelaxationConstant); float supervisedLoss = computeSupervisedLoss(); - float totalProxValue = computeTotalProxValue(new float[proxRuleObservedAtoms.length]); - - log.trace("Prox Loss: {}, Objective difference: {}, Constraint Violation: {}, Supervised Loss: {}, Energy Loss: {}.", - totalProxValue, objectiveDifference, constraintViolation, supervisedLoss, latentInferenceEnergy); return (squaredPenaltyCoefficient / 2.0f) * (float)Math.pow(constraintViolation, 2.0f) + linearPenaltyCoefficient * (constraintViolation) diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java index a55d14dd1..83f6038b4 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java @@ -542,10 +542,6 @@ protected float computeLearningLoss() { float objectiveDifference = augmentedInferenceEnergy - mapEnergy; float constraintViolation = Math.max(0.0f, objectiveDifference - constraintRelaxationConstant); float supervisedLoss = computeSupervisedLoss(); - float totalProxValue = computeTotalProxValue(new float[proxRuleObservedAtoms.length]); - - log.trace("Prox Loss: {}, Objective difference: {}, Constraint Violation: {}, Supervised Loss: {}, Energy Loss: {}.", - totalProxValue, objectiveDifference, constraintViolation, supervisedLoss, latentInferenceEnergy); return (squaredPenaltyCoefficient / 2.0f) * (float)Math.pow(constraintViolation, 2.0f) + linearPenaltyCoefficient * (constraintViolation) diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java index 1f0b41692..9f07626e1 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java @@ -8,7 +8,6 @@ import org.linqs.psl.model.atom.ObservedAtom; import org.linqs.psl.model.atom.RandomVariableAtom; import org.linqs.psl.model.deep.DeepModelPredicate; -import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.model.rule.Rule; import org.linqs.psl.reasoner.term.ReasonerTerm; import org.linqs.psl.reasoner.term.SimpleTermStore; @@ -39,6 +38,7 @@ public enum PolicyUpdate { private final DeepAtomPolicyDistribution deepAtomPolicyDistribution; private final PolicyUpdate policyUpdate; + private float[] scores; private float scoreMovingAverage; private float[] sampleProbabilities; @@ -65,6 +65,8 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase()); policyUpdate = PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase()); + + scores = null; scoreMovingAverage = 0.0f; sampleProbabilities = null; @@ -91,15 +93,17 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t protected void initForLearning() { super.initForLearning(); - scoreMovingAverage = 0.0f; + if (symbolicWeightLearning){ + throw new IllegalArgumentException("Policy Gradient does not support symbolic weight learning."); + } + + scoreMovingAverage = Float.POSITIVE_INFINITY; } protected abstract void computeSupervisedLoss(); @Override protected float computeLearningLoss() { - log.trace("Supervised Loss: {}, Energy Loss: {}.", supervisedLoss, latentInferenceEnergy); - return supervisedLoss + energyLossCoefficient * latentInferenceEnergy; } @@ -130,93 +134,102 @@ protected void setBatch(int batch) { latentInferenceAtomValueState = batchLatentInferenceAtomValueStates.get(batch); initialDeepAtomValues = new float[trainInferenceApplication.getTermStore().getAtomStore().size()]; policySampledDeepAtomValues = new float[trainInferenceApplication.getTermStore().getAtomStore().size()]; + scores = new float[trainInferenceApplication.getTermStore().getAtomStore().size()]; } @Override protected void computeIterationStatistics() { - sampleDeepAtomValues(); - // TODO(Charles): Sample symbolic weights. + Arrays.fill(scores, 0.0f); + + for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) { + Map> atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories(); + for (Map.Entry> entry : atomIdentiferToCategories.entrySet()) { + ArrayList categories = entry.getValue(); + + sampleDeepAtomValues(categories); + + computeMAPInferenceStatistics(); + computeSupervisedLoss(); + computeLatentInferenceStatistics(); + + float score = computeScore(); + for (RandomVariableAtom category : categories) { + int atomIndex = trainInferenceApplication.getTermStore().getAtomStore().getAtomIndex(category); + scores[atomIndex] = score; + } + + resetDeepAtomValues(categories); + } + } + updateScoreMovingAverage(); computeMAPInferenceStatistics(); computeSupervisedLoss(); computeLatentInferenceStatistics(); - - // TODO(Charles): Reset symbolic weights. - resetDeepAtomValues(); } /** * Sample the deep atom values according to a policy parameterized by the deep model predictions. */ - protected void sampleDeepAtomValues() { + protected void sampleDeepAtomValues(ArrayList categories) { // Save the initial deep model predictions to reset the deep atom values after computing iteration statistics. AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); - for (int i = 0; i < atomStore.size(); i++) { - GroundAtom atom = atomStore.getAtoms()[i]; - if (atom.getPredicate() instanceof DeepPredicate) { - initialDeepAtomValues[i] = atom.getValue(); - } else { - initialDeepAtomValues[i] = 0.0f; - policySampledDeepAtomValues[i] = 0.0f; - } + + Arrays.fill(initialDeepAtomValues, 0.0f); + Arrays.fill(policySampledDeepAtomValues, 0.0f); + + for (RandomVariableAtom category : categories) { + int atomIndex = atomStore.getAtomIndex(category); + initialDeepAtomValues[atomIndex] = atomStore.getAtomValues()[atomIndex]; } switch (deepAtomPolicyDistribution) { case CATEGORICAL: - sampleCategorical(); + sampleCategorical(categories); break; default: throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution); } } - private void sampleCategorical() { + private void sampleCategorical(ArrayList categories) { AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); float[] atomValues = atomStore.getAtomValues(); sampleProbabilities = new float[atomStore.size()]; // Sample the deep model predictions according to the stochastic categorical policy. - for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) { - Map > atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories(); - for (Map.Entry> entry : atomIdentiferToCategories.entrySet()) { - ArrayList categories = entry.getValue(); - - float[] categoryProbabilities = new float[categories.size()]; - for (int i = 0; i < categories.size(); i++) { - categoryProbabilities[i] = categories.get(i).getValue(); - } - - int sampledCategoryIndex = RandUtils.sampleCategorical(categoryProbabilities); + float[] categoryProbabilities = new float[categories.size()]; + for (int i = 0; i < categories.size(); i++) { + categoryProbabilities[i] = categories.get(i).getValue(); + } - for (int i = 0; i < categories.size(); i++) { - int atomIndex = atomStore.getAtomIndex(categories.get(i)); + int sampledCategoryIndex = RandUtils.sampleCategorical(categoryProbabilities); - if (i != sampledCategoryIndex) { - categories.get(i).setValue(0.0f); - } else { - sampleProbabilities[atomIndex] = categoryProbabilities[i]; - categories.get(i).setValue(1.0f); - } + for (int i = 0; i < categories.size(); i++) { + int atomIndex = atomStore.getAtomIndex(categories.get(i)); - policySampledDeepAtomValues[atomIndex] = categories.get(i).getValue(); - atomValues[atomIndex] = categories.get(i).getValue(); - } + if (i != sampledCategoryIndex) { + categories.get(i).setValue(0.0f); + } else { + sampleProbabilities[atomIndex] = categoryProbabilities[i]; + categories.get(i).setValue(1.0f); } + + policySampledDeepAtomValues[atomIndex] = categories.get(i).getValue(); + atomValues[atomIndex] = categories.get(i).getValue(); } } - private void resetDeepAtomValues() { + private void resetDeepAtomValues(ArrayList categories) { AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); float[] atomValues = atomStore.getAtomValues(); - for (int i = 0; i < atomStore.size(); i++) { - GroundAtom atom = atomStore.getAtoms()[i]; - if (atom.getPredicate() instanceof DeepPredicate) { - ((RandomVariableAtom) atom).setValue(initialDeepAtomValues[i]); - atomValues[i] = initialDeepAtomValues[i]; - } + for (RandomVariableAtom category : categories) { + int atomIndex = atomStore.getAtomIndex(category); + category.setValue(initialDeepAtomValues[atomIndex]); + atomValues[atomIndex] = initialDeepAtomValues[atomIndex]; } } @@ -253,13 +266,7 @@ protected void computeLatentInferenceStatistics() { @Override protected void addLearningLossWeightGradient() { - // TODO(Charles): Energy Loss Policy Gradient. - // Energy loss gradient. - for (int i = 0; i < mutableRules.size(); i++) { - weightGradient[i] += energyLossCoefficient * latentInferenceIncompatibility[i]; - } - - // TODO(Charles): Supervised loss gradient. + throw new UnsupportedOperationException("Policy Gradient does not support learning symbolic weights."); } @Override @@ -291,21 +298,46 @@ private void computeCategoricalAtomGradient(int atomIndex) { return; } - float score = energyLossCoefficient * latentInferenceEnergy + supervisedLoss; - scoreMovingAverage = 0.9f * scoreMovingAverage + 0.1f * score; - switch (policyUpdate) { case REINFORCE: - deepAtomGradient[atomIndex] += score / sampleProbabilities[atomIndex]; + deepAtomGradient[atomIndex] += scores[atomIndex] / sampleProbabilities[atomIndex]; break; case REINFORCE_BASELINE: - deepAtomGradient[atomIndex] += (score - scoreMovingAverage) / sampleProbabilities[atomIndex]; + deepAtomGradient[atomIndex] += (scores[atomIndex] - scoreMovingAverage) / sampleProbabilities[atomIndex]; break; default: throw new IllegalArgumentException("Unknown policy update: " + policyUpdate); } } + private float computeScore() { + return energyLossCoefficient * latentInferenceEnergy + supervisedLoss; + } + + private void updateScoreMovingAverage() { + float scoreAverage = 0.0f; + int numScores = 0; + for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) { + Map> atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories(); + for (Map.Entry> entry : atomIdentiferToCategories.entrySet()) { + ArrayList categories = entry.getValue(); + + for (RandomVariableAtom category : categories) { + int atomIndex = trainInferenceApplication.getTermStore().getAtomStore().getAtomIndex(category); + scoreAverage += scores[atomIndex]; + numScores += 1; + } + } + } + scoreAverage /= numScores; + + if (!Float.isInfinite(scoreMovingAverage)) { + scoreMovingAverage = 0.9f * scoreMovingAverage + 0.1f * scoreAverage; + } else { + scoreMovingAverage = scoreAverage; + } + } + /** * Set RandomVariableAtoms with labels to their observed (truth) value. diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java index a081485a8..799f91beb 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java @@ -22,6 +22,7 @@ import org.linqs.psl.model.atom.ObservedAtom; import org.linqs.psl.model.atom.RandomVariableAtom; import org.linqs.psl.model.rule.Rule; +import org.linqs.psl.util.Logger; import org.linqs.psl.util.MathUtils; import java.util.List; @@ -32,6 +33,8 @@ * using the policy gradient learning framework. */ public class PolicyGradientBinaryCrossEntropy extends PolicyGradient { + private static final Logger log = Logger.getLogger(PolicyGradientBinaryCrossEntropy.class); + public PolicyGradientBinaryCrossEntropy(List rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) { super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation); @@ -42,7 +45,8 @@ protected void computeSupervisedLoss() { AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); supervisedLoss = 0.0f; - for (Map.Entry entry: trainingMap.getLabelMap().entrySet()) { + int numEvaluatedAtoms = 0; + for (Map.Entry entry : trainingMap.getLabelMap().entrySet()) { RandomVariableAtom randomVariableAtom = entry.getKey(); ObservedAtom observedAtom = entry.getValue(); @@ -54,6 +58,12 @@ protected void computeSupervisedLoss() { supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT)) + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT))); + + numEvaluatedAtoms++; + } + + if (numEvaluatedAtoms > 0) { + supervisedLoss /= numEvaluatedAtoms; } } } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java index 48d78cea5..77b356d38 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java @@ -41,6 +41,7 @@ protected void computeSupervisedLoss() { AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); supervisedLoss = 0.0f; + int numEvaluatedAtoms = 0; for (Map.Entry entry: trainingMap.getLabelMap().entrySet()) { RandomVariableAtom randomVariableAtom = entry.getKey(); ObservedAtom observedAtom = entry.getValue(); @@ -52,6 +53,11 @@ protected void computeSupervisedLoss() { } supervisedLoss += Math.pow(atomStore.getAtom(atomIndex).getValue() - observedAtom.getValue(), 2.0f); + numEvaluatedAtoms++; + } + + if (numEvaluatedAtoms > 0) { + supervisedLoss /= numEvaluatedAtoms; } } } diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java index a7ec3358c..2224c35fb 100644 --- a/psl-core/src/main/java/org/linqs/psl/config/Options.java +++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java @@ -322,13 +322,19 @@ public class Options { ); public static final Option WLA_GRADIENT_DESCENT_EXTENSION = new Option( - "gradientdescent.extension", - GradientDescent.GDExtension.MIRROR_DESCENT.toString(), + "gradientdescent.symbolicweightupdate", + GradientDescent.SymbolicWeightUpdate.MIRROR_DESCENT.toString(), "The gradient descent extension to use for gradient descent weight learning." + " MIRROR_DESCENT (Default): Mirror descent / normalized exponentiated gradient descent over the unit simplex." + " If this option is chosen then gradientdescent.negativelogregularization must be positive." + " PROJECTED_GRADIENT: Projected gradient descent over the unit simplex." - + " NONE: Gradient descent over non-negative orthant." + + " GRADIENT_DESCENT: Gradient descent over non-negative orthant." + ); + + public static final Option WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING = new Option( + "gradientdescent.symbolicweightlearning", + true, + "Whether to perform symbolic weight learning during gradient descent." ); public static final Option WLA_GRADIENT_DESCENT_L2_REGULARIZATION = new Option( diff --git a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java index a41e89199..a76259f21 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java @@ -173,6 +173,10 @@ public int init() { return Integer.SIZE + maxDataIndex * Integer.SIZE + maxDataIndex * classSize * Float.SIZE; } + public Predicate getPredicate() { + return predicate; + } + public void writeFitData() { log.debug("Writing fit data for deep model predicate: {}", predicate.getName()); for (int index = 0; index < gradients.length; index++) { diff --git a/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java b/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java index 5c9094abf..d4c31e030 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java @@ -128,11 +128,11 @@ public void setPredicateOption(String name, Object option) { integer = true; } - if (name.equals("categorical") && Boolean.parseBoolean(option.toString())) { + if (name.equals("Categorical") && Boolean.parseBoolean(option.toString())) { categorical = true; } - if (name.equals("categoricalIndexes")) { + if (name.equals("CategoricalIndexes")) { categoryIndexes = StringUtils.splitInt(option.toString(), DELIM); for (int categoryIndex : categoryIndexes) { identifierIndexes[categoryIndex] = -1; diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java b/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java index 44057db5d..bf5b7fa1d 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/Reasoner.java @@ -130,7 +130,7 @@ protected void initForOptimization(TermStore termStore) { protected void optimizationComplete(TermStore termStore, ObjectiveResult finalObjective, long totalTime) { float change = (float)termStore.sync(); - log.info("Final Objective: {}, Violated Constraints: {}, Total Optimization Time: {}", + log.debug("Final Objective: {}, Violated Constraints: {}, Total Optimization Time: {}", finalObjective.objective, finalObjective.violatedConstraints, totalTime); log.debug("Movement of variables from initial state: {}", change);