diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java
index 1ea894994..fa3ed7e6c 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java
@@ -54,8 +54,8 @@ protected float computeSupervisedLoss() {
 
             int proxRuleIndex = rvAtomIndexToProxRuleIndex.get(atomIndex);
 
-            supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))
-                    + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT)));
+            supervisedLoss += (float) (-1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))
+                    + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))));
         }
 
         return supervisedLoss;
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java
index 1b949959f..e6a3b761a 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java
@@ -51,7 +51,7 @@ protected float computeSupervisedLoss() {
                 continue;
             }
 
-            supervisedLoss += Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f);
+            supervisedLoss += (float) Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f);
         }
 
         return supervisedLoss;
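
Supplementary note (not part of the patch): Math.log and Math.pow both return double, so each loss term in these methods is a double-valued expression. The compound assignment "+=" would compile even without a cast, because Java inserts an implicit narrowing conversion back to the float accumulator; the added (float) casts make that double-to-float narrowing explicit. A minimal standalone sketch under those assumptions, using hypothetical values and a hypothetical class name rather than PSL code:

    public class NarrowingCastSketch {
        public static void main(String[] args) {
            float supervisedLoss = 0.0f;
            float predicted = 0.75f;  // hypothetical proximity-rule atom value
            float observed = 1.0f;    // hypothetical training label

            // Math.pow returns double; the explicit cast narrows each term to float
            // before it is added to the float accumulator, rather than relying on
            // the implicit narrowing that "+=" performs on its own.
            supervisedLoss += (float) Math.pow(predicted - observed, 2.0f);

            System.out.println(supervisedLoss);  // prints 0.0625
        }
    }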