diff --git a/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java b/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java index 72dc499cb..b1de213da 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/predicate/DeepPredicate.java @@ -237,9 +237,12 @@ public static void nextBatchAllDeepPredicates() { */ public static boolean isEpochCompleteAllDeepPredicates() { boolean isEpochComplete = false; + boolean deepPredicates = false; for (Predicate predicate : Predicate.getAll()) { if (predicate instanceof DeepPredicate) { + deepPredicates = true; + isEpochComplete = (((DeepPredicate) predicate).isEpochComplete()); if (isEpochComplete) { @@ -248,7 +251,7 @@ public static boolean isEpochCompleteAllDeepPredicates() { } } - return isEpochComplete; + return (!deepPredicates) || isEpochComplete; } private void writeObject(ObjectOutputStream out) throws IOException { diff --git a/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java b/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java index 24d8ffb8f..c9a3057f0 100644 --- a/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java +++ b/psl-java/src/main/java/org/linqs/psl/config/RuntimeOptions.java @@ -125,12 +125,6 @@ public class RuntimeOptions { "Clear learning rules before inference. Useful when switching models between train and test." ); - public static final Option INFERENCE_DEEP_BATCHING = new Option( - "runtime.inference.deep.batching", - false, - "Whether deep models are batched. Inference should rerun until deep models terminates." - ); - public static final Option LEARN = new Option( "runtime.learn", false, diff --git a/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java b/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java index 4307af55d..d1c742baa 100644 --- a/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java +++ b/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java @@ -448,12 +448,7 @@ protected void runInferenceInternal(RuntimeConfig config, Model model, RuntimeRe DeepPredicate.evalAllDeepPredicates(); - if (RuntimeOptions.INFERENCE_DEEP_BATCHING.getBoolean()) { - DeepPredicate.nextBatchAllDeepPredicates(); - runInference = !DeepPredicate.isEpochCompleteAllDeepPredicates(); - } else { - runInference = false; - } + runInference = !DeepPredicate.isEpochCompleteAllDeepPredicates(); } DeepPredicate.epochEndAllDeepPredicates();