From 936aa5d6dcd586c18ed71129df627fab3ac6a34a Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Tue, 28 Nov 2023 10:32:56 -0800 Subject: [PATCH] Use batch atom stores to create batch training maps. --- .../inference/InferenceApplication.java | 2 +- .../learning/weight/TrainingMap.java | 20 ++++----- .../weight/WeightLearningApplication.java | 4 +- .../weight/gradient/GradientDescent.java | 16 ++++--- .../batchgenerator/BatchGenerator.java | 44 ++++++++++++++++--- .../ConnectedComponentBatchGenerator.java | 14 ++++-- .../batchgenerator/FullBatchGenerator.java | 9 ++-- .../batchgenerator/NeuralBatchGenerator.java | 9 ++-- .../psl/evaluation/statistics/Evaluator.java | 2 +- .../learning/weight/TrainingMapTest.java | 2 +- .../statistics/CategoricalEvaluatorTest.java | 2 +- .../evaluation/statistics/EvaluatorTest.java | 2 +- 12 files changed, 89 insertions(+), 37 deletions(-) diff --git a/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java b/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java index 80cc40234..a9246e3a8 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java +++ b/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java @@ -170,7 +170,7 @@ public double inference(boolean commitAtoms, boolean reset, List 0) { - trainingMap = new TrainingMap(database, truthDatabase); + trainingMap = new TrainingMap(database.getAtomStore(), truthDatabase.getAtomStore()); } log.info("Beginning inference."); diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/TrainingMap.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/TrainingMap.java index b11f7f5f8..447698a8f 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/TrainingMap.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/TrainingMap.java @@ -17,12 +17,12 @@ */ package org.linqs.psl.application.learning.weight; +import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.model.atom.ObservedAtom; import org.linqs.psl.model.atom.RandomVariableAtom; import org.linqs.psl.model.predicate.FunctionalPredicate; -import org.linqs.psl.model.predicate.StandardPredicate; import org.linqs.psl.util.IteratorUtils; import org.linqs.psl.util.Logger; @@ -84,11 +84,11 @@ public class TrainingMap { /** * Initializes the training map of RandomVariableAtoms ObservedAtoms. * - * @param targetDatabase the database containing the RandomVariableAtoms (any other atom types are ignored) - * @param truthDatabase the database containing matching ObservedAtoms + * @param targetAtomStore the atomstore containing the RandomVariableAtoms (any other atom types are ignored) + * @param truthAtomStore the atomstore containing matching ObservedAtoms */ - public TrainingMap(Database targetDatabase, Database truthDatabase) { - labelMap = new HashMap(targetDatabase.getAtomStore().size()); + public TrainingMap(AtomStore targetAtomStore, AtomStore truthAtomStore) { + labelMap = new HashMap(targetAtomStore.size()); observedMap = new HashMap(); latentVariables = new ArrayList(); missingLabels = new ArrayList(); @@ -96,15 +96,15 @@ public TrainingMap(Database targetDatabase, Database truthDatabase) { Set seenTruthAtoms = new HashSet(); - for (GroundAtom targetAtom : targetDatabase.getAtomStore()) { + for (GroundAtom targetAtom : targetAtomStore) { if (targetAtom.getPredicate() instanceof FunctionalPredicate) { continue; } // Note that hasAtom() will not return true for an unmanaged atom (except pre-cached functional predicates). GroundAtom truthAtom = null; - if (truthDatabase.getAtomStore().hasAtom(targetAtom.getPredicate(), targetAtom.getArguments())) { - truthAtom = truthDatabase.getAtomStore().getAtom(targetAtom.getPredicate(), targetAtom.getArguments()); + if (truthAtomStore.hasAtom(targetAtom)) { + truthAtom = truthAtomStore.getAtom(truthAtomStore.getAtomIndex(targetAtom)); } // Skip any truth atom that is not observed. @@ -129,12 +129,12 @@ public TrainingMap(Database targetDatabase, Database truthDatabase) { } } - for (GroundAtom truthAtom : truthDatabase.getAtomStore()) { + for (GroundAtom truthAtom : truthAtomStore) { if (!(truthAtom instanceof ObservedAtom) || seenTruthAtoms.contains(truthAtom)) { continue; } - boolean hasAtom = targetDatabase.getAtomStore().hasAtom(truthAtom.getPredicate(), truthAtom.getArguments()); + boolean hasAtom = targetAtomStore.hasAtom(truthAtom); if (hasAtom) { // This shouldn't be possible (since we already iterated through the target atoms). // This means that the target is not cached. diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java index e04efd05e..915d08aca 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/WeightLearningApplication.java @@ -178,8 +178,8 @@ private void initGroundModel(InferenceApplication trainInferenceApplication, Inf return; } - TrainingMap trainingMap = new TrainingMap(trainInferenceApplication.getDatabase(), trainTruthDatabase); - TrainingMap validationMap = new TrainingMap(validationInferenceApplication.getDatabase(), validationTruthDatabase); + TrainingMap trainingMap = new TrainingMap(trainInferenceApplication.getDatabase().getAtomStore(), trainTruthDatabase.getAtomStore()); + TrainingMap validationMap = new TrainingMap(validationInferenceApplication.getDatabase().getAtomStore(), validationTruthDatabase.getAtomStore()); initGroundModel(trainInferenceApplication, trainingMap, validationInferenceApplication, validationMap); } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 1e695a38a..67d58d893 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -18,6 +18,7 @@ package org.linqs.psl.application.learning.weight.gradient; import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.application.learning.weight.TrainingMap; import org.linqs.psl.application.learning.weight.WeightLearningApplication; import org.linqs.psl.application.learning.weight.gradient.batchgenerator.BatchGenerator; import org.linqs.psl.config.Options; @@ -72,11 +73,12 @@ public static enum GDExtension { protected int trainingEvaluationComputePeriod; protected SimpleTermStore trainFullTermStore; + protected TrainingMap fullTrainingMap; protected List trainFullDeepModelPredicates; protected TermState[] trainFullMAPTermState; protected float[] trainFullMAPAtomValueState; - double currentTrainingEvaluationMetric; - double bestTrainingEvaluationMetric; + protected double currentTrainingEvaluationMetric; + protected double bestTrainingEvaluationMetric; protected boolean fullMAPEvaluationBreak; protected int fullMAPEvaluationPatience; protected int lastTrainingImprovementEpoch; @@ -136,6 +138,7 @@ public GradientDescent(List rules, Database trainTargetDatabase, Database trainingEvaluationComputePeriod = Options.WLA_GRADIENT_DESCENT_TRAINING_COMPUTE_PERIOD.getInt(); trainFullTermStore = null; + fullTrainingMap = null; trainFullDeepModelPredicates = null; trainFullMAPTermState = null; trainFullMAPAtomValueState = null; @@ -217,7 +220,9 @@ protected void validateState() { } protected void initializeFullModels() { - this.trainFullTermStore = (SimpleTermStore)trainInferenceApplication.getTermStore(); + trainFullTermStore = (SimpleTermStore)trainInferenceApplication.getTermStore(); + + fullTrainingMap = trainingMap; trainFullDeepModelPredicates = deepModelPredicates; @@ -234,7 +239,7 @@ protected void initializeBatches() { } batchGenerator = BatchGenerator.getBatchGenerator(Options.WLA_GRADIENT_DESCENT_BATCH_GENERATOR.getString(), - trainInferenceApplication, trainFullTermStore, deepPredicates); + trainInferenceApplication, trainFullTermStore, deepPredicates, trainTruthDatabase.getAtomStore()); batchGenerator.generateBatches(); } @@ -471,7 +476,7 @@ protected void measureEpochParameterMovement() { protected void setFullModel() { trainInferenceApplication.setTermStore(trainFullTermStore); - + trainingMap = fullTrainingMap; trainMAPTermState = trainFullMAPTermState; trainMAPAtomValueState = trainFullMAPAtomValueState; @@ -488,6 +493,7 @@ protected void setBatch(int batch) { List batchDeepModelPredicates = batchGenerator.getBatchDeepModelPredicates(batch); trainInferenceApplication.setTermStore(batchTermStore); + trainingMap = batchGenerator.getBatchTrainingMap(batch); trainMAPTermState = batchMAPTermStates.get(batch); trainMAPAtomValueState = batchMAPAtomValueStates.get(batch); diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java index 67599caef..fbaed727e 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java @@ -18,6 +18,8 @@ package org.linqs.psl.application.learning.weight.gradient.batchgenerator; import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.application.learning.weight.TrainingMap; +import org.linqs.psl.database.AtomStore; import org.linqs.psl.model.deep.DeepModelPredicate; import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.reasoner.term.ReasonerTerm; @@ -37,21 +39,28 @@ public abstract class BatchGenerator { protected InferenceApplication inferenceApplication; protected SimpleTermStore fullTermStore; + protected AtomStore fullTruthAtomStore; protected List deepPredicates; protected List> batchTermStores; + protected List batchTruthAtomStores; + protected List batchTrainingMaps; protected List> batchDeepModelPredicates; protected ArrayList batchPermutation; protected int currentBatchPermutationIndex; - public BatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, List deepPredicates) { + public BatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, + List deepPredicates, AtomStore fullTruthAtomStore) { this.inferenceApplication = inferenceApplication; this.fullTermStore = fullTermStore; + this.fullTruthAtomStore = fullTruthAtomStore; this.deepPredicates = deepPredicates; batchTermStores = new ArrayList>(); + batchTruthAtomStores = new ArrayList(); + batchTrainingMaps = new ArrayList(); batchDeepModelPredicates = new ArrayList>(); batchPermutation = new ArrayList(); @@ -85,11 +94,33 @@ public List getBatchDeepModelPredicates(int index) { return batchDeepModelPredicates.get(index); } + public List getBatchTruthAtomStores() { + return batchTruthAtomStores; + } + + public AtomStore getBatchTruthAtomStore(int index) { + return batchTruthAtomStores.get(index); + } + + public List getBatchTrainingMaps() { + return batchTrainingMaps; + } + + public TrainingMap getBatchTrainingMap(int index) { + return batchTrainingMaps.get(index); + } + public void generateBatches() { clear(); - generateBatchTermStores(); + generateBatchesInternal(); + + // Generate batch training maps. + for (int i = 0; i < numBatchTermStores(); i++) { + batchTrainingMaps.add(new TrainingMap(batchTermStores.get(i).getAtomStore(), batchTruthAtomStores.get(i))); + } + // Generate batch deep model predicates. for (int i = 0; i < numBatchTermStores(); i++) { SimpleTermStore batchTermStore = batchTermStores.get(i); batchDeepModelPredicates.add(new ArrayList()); @@ -101,12 +132,13 @@ public void generateBatches() { } } + // Generate initial batch permutation. for (int i = 0; i < numBatchTermStores(); i++) { batchPermutation.add(i); } } - public abstract void generateBatchTermStores(); + protected abstract void generateBatchesInternal(); /** * Permute the order of the batches. @@ -184,7 +216,7 @@ public void close() { */ public static BatchGenerator getBatchGenerator(String name, InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, - List deepPredicates) { + List deepPredicates, AtomStore fullTruthAtomStore) { String className = Reflection.resolveClassName(name); if (className == null) { throw new IllegalArgumentException("Could not find class: " + name); @@ -201,14 +233,14 @@ public static BatchGenerator getBatchGenerator(String name, InferenceApplication Constructor constructor = null; try { - constructor = classObject.getConstructor(InferenceApplication.class, SimpleTermStore.class, List.class); + constructor = classObject.getConstructor(InferenceApplication.class, SimpleTermStore.class, List.class, AtomStore.class); } catch (NoSuchMethodException ex) { throw new IllegalArgumentException("No suitable constructor found for batch generator: " + className + ".", ex); } BatchGenerator batchGenerator = null; try { - batchGenerator = constructor.newInstance(inferenceApplication, fullTermStore, deepPredicates); + batchGenerator = constructor.newInstance(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore); } catch (InstantiationException ex) { throw new RuntimeException("Unable to instantiate weight learner (" + className + ")", ex); } catch (IllegalAccessException ex) { diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/ConnectedComponentBatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/ConnectedComponentBatchGenerator.java index bc7c4179c..fe7362e96 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/ConnectedComponentBatchGenerator.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/ConnectedComponentBatchGenerator.java @@ -35,17 +35,19 @@ public class ConnectedComponentBatchGenerator extends BatchGenerator { private final int batchSize; - public ConnectedComponentBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, List deepPredicates) { - super(inferenceApplication, fullTermStore, deepPredicates); + public ConnectedComponentBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, + List deepPredicates, AtomStore fullTruthAtomStore) { + super(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore); batchSize = Options.WLA_CONNECTED_COMPONENT_BATCH_SIZE.getInt(); } @Override - public void generateBatchTermStores() { + public void generateBatchesInternal() { AtomStore fullAtomStore = fullTermStore.getAtomStore(); AtomStore batchAtomStore = new AtomStore(); + batchTruthAtomStores.add(new AtomStore()); SimpleTermStore batchTermStore = (SimpleTermStore) inferenceApplication.createTermStore(); batchTermStore.setAtomStore(batchAtomStore); @@ -58,6 +60,7 @@ public void generateBatchTermStores() { batchNumComponents = 0; batchAtomStore = new AtomStore(); + batchTruthAtomStores.add(new AtomStore()); batchTermStore = (SimpleTermStore) inferenceApplication.createTermStore(); batchTermStore.setAtomStore(batchAtomStore); } @@ -71,6 +74,11 @@ public void generateBatchTermStores() { GroundAtom atom = fullAtomStore.getAtom(originalAtomIndexes[i]); if (!batchAtomStore.hasAtom(atom)) { batchAtomStore.addAtom(atom.copy()); + + // Add the atom to the truth atom store if it has a truth atom. + if (fullTruthAtomStore.hasAtom(atom)) { + batchTruthAtomStores.get(batchTruthAtomStores.size() - 1).addAtom(fullTruthAtomStore.getAtom(fullTruthAtomStore.getAtomIndex(atom))); + } } newAtomIndexes[i] = batchAtomStore.getAtomIndex(atom); diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/FullBatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/FullBatchGenerator.java index 7e2451989..75face8e6 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/FullBatchGenerator.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/FullBatchGenerator.java @@ -18,6 +18,7 @@ package org.linqs.psl.application.learning.weight.gradient.batchgenerator; import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.database.AtomStore; import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.reasoner.term.ReasonerTerm; import org.linqs.psl.reasoner.term.SimpleTermStore; @@ -30,12 +31,14 @@ */ public class FullBatchGenerator extends BatchGenerator { - public FullBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, List deepPredicates) { - super(inferenceApplication, fullTermStore, deepPredicates); + public FullBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, + List deepPredicates, AtomStore fullTruthAtomStore) { + super(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore); } @Override - public void generateBatchTermStores() { + public void generateBatchesInternal() { batchTermStores.add(fullTermStore.copy()); + batchTruthAtomStores.add(fullTruthAtomStore.copy()); } } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/NeuralBatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/NeuralBatchGenerator.java index 172d8556f..58712561b 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/NeuralBatchGenerator.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/NeuralBatchGenerator.java @@ -18,6 +18,7 @@ package org.linqs.psl.application.learning.weight.gradient.batchgenerator; import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.database.AtomStore; import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.reasoner.term.ReasonerTerm; import org.linqs.psl.reasoner.term.SimpleTermStore; @@ -34,8 +35,9 @@ public class NeuralBatchGenerator extends BatchGenerator { int batchCount; int numBatches; - public NeuralBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, List deepPredicates) { - super(inferenceApplication, fullTermStore, deepPredicates); + public NeuralBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore fullTermStore, + List deepPredicates, AtomStore fullTruthAtomStore) { + super(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore); assert deepPredicates.size() >= 1; @@ -59,8 +61,9 @@ public int epochStart() { } @Override - public void generateBatchTermStores() { + public void generateBatchesInternal() { batchTermStores.add(fullTermStore.copy()); + batchTruthAtomStores.add(fullTruthAtomStore.copy()); } @Override diff --git a/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/Evaluator.java b/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/Evaluator.java index 6d394d270..3bed83a03 100644 --- a/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/Evaluator.java +++ b/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/Evaluator.java @@ -131,7 +131,7 @@ public double getNormalizedMaxRepMetric() { * A convenience call for those who don't want to create a training map directly. */ public void compute(Database rvDB, Database truthDB, StandardPredicate predicate) { - TrainingMap map = new TrainingMap(rvDB, truthDB); + TrainingMap map = new TrainingMap(rvDB.getAtomStore(), truthDB.getAtomStore()); compute(map, predicate); } diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/TrainingMapTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/TrainingMapTest.java index e35dfa729..ac01b7f6a 100644 --- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/TrainingMapTest.java +++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/TrainingMapTest.java @@ -109,7 +109,7 @@ public void setUp() { targetsDatabase = dataStore.getDatabase(targetOpenPartition, targetClosedPartition); truthDatabase = dataStore.getDatabase(truthOpenPartition, truthClosedPartition); - trainingMap = new TrainingMap(targetsDatabase, truthDatabase); + trainingMap = new TrainingMap(targetsDatabase.getAtomStore(), truthDatabase.getAtomStore()); } @After diff --git a/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluatorTest.java b/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluatorTest.java index d27c0bcb4..dde4cead2 100644 --- a/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluatorTest.java +++ b/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluatorTest.java @@ -76,7 +76,7 @@ public void setUp() { Database results = dataStore.getDatabase(targetPartition); Database truth = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates()); - trainingMap = new TrainingMap(results, truth); + trainingMap = new TrainingMap(results.getAtomStore(), truth.getAtomStore()); // Since we only need the map, we can close all the databases. results.close(); diff --git a/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/EvaluatorTest.java b/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/EvaluatorTest.java index 5b4473c9a..47ae294d2 100644 --- a/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/EvaluatorTest.java +++ b/psl-core/src/test/java/org/linqs/psl/evaluation/statistics/EvaluatorTest.java @@ -84,7 +84,7 @@ protected void init(float[] predictions, float[] truth) { Database resultsDB = dataStore.getDatabase(targetPartition); Database truthDB = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates()); - trainingMap = new TrainingMap(resultsDB, truthDB); + trainingMap = new TrainingMap(resultsDB.getAtomStore(), truthDB.getAtomStore()); // Since we only need the map, we can close all the databases. resultsDB.close();