From f5bf9b96e82774f05e09733de6a66f91e1fc00be Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Tue, 16 Jan 2024 13:00:50 -0800 Subject: [PATCH] Cast loss copmutations in learning as floats. --- .../weight/gradient/minimizer/BinaryCrossEntropy.java | 4 ++-- .../learning/weight/gradient/minimizer/SquaredError.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java index 1ea894994..fa3ed7e6c 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.java @@ -54,8 +54,8 @@ protected float computeSupervisedLoss() { int proxRuleIndex = rvAtomIndexToProxRuleIndex.get(atomIndex); - supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT)) - + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))); + supervisedLoss += (float) (-1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT)) + + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT)))); } return supervisedLoss; diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java index 1b949959f..e6a3b761a 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/SquaredError.java @@ -51,7 +51,7 @@ protected float computeSupervisedLoss() { continue; } - supervisedLoss += Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f); + supervisedLoss += (float) Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f); } return supervisedLoss;