
Commit
Use batch atom stores to create batch training maps.
dickensc committed Nov 28, 2023
1 parent 7d0d604 commit 936aa5d
Showing 12 changed files with 89 additions and 37 deletions.

@@ -170,7 +170,7 @@ public double inference(boolean commitAtoms, boolean reset, List<EvaluationInsta
         TrainingMap trainingMap = null;
         if (truthDatabase != null && evaluations != null && evaluations.size() > 0) {
-            trainingMap = new TrainingMap(database, truthDatabase);
+            trainingMap = new TrainingMap(database.getAtomStore(), truthDatabase.getAtomStore());
         }

         log.info("Beginning inference.");
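The change in one line: TrainingMap is now constructed from AtomStores rather than Databases. A minimal before/after sketch of the call-site migration, using only the names and calls visible in the hunk above:

    // Before this commit: the constructor took the databases themselves.
    // trainingMap = new TrainingMap(database, truthDatabase);

    // After this commit: callers pass the databases' backing atom stores.
    trainingMap = new TrainingMap(database.getAtomStore(), truthDatabase.getAtomStore());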

@@ -17,12 +17,12 @@
  */
 package org.linqs.psl.application.learning.weight;

+import org.linqs.psl.database.AtomStore;
-import org.linqs.psl.database.Database;
 import org.linqs.psl.model.atom.GroundAtom;
 import org.linqs.psl.model.atom.ObservedAtom;
 import org.linqs.psl.model.atom.RandomVariableAtom;
 import org.linqs.psl.model.predicate.FunctionalPredicate;
 import org.linqs.psl.model.predicate.StandardPredicate;
 import org.linqs.psl.util.IteratorUtils;
 import org.linqs.psl.util.Logger;
@@ -84,27 +84,27 @@ public class TrainingMap {
     /**
      * Initializes the training map of RandomVariableAtoms to ObservedAtoms.
      *
-     * @param targetDatabase the database containing the RandomVariableAtoms (any other atom types are ignored)
-     * @param truthDatabase the database containing matching ObservedAtoms
+     * @param targetAtomStore the atom store containing the RandomVariableAtoms (any other atom types are ignored)
+     * @param truthAtomStore the atom store containing matching ObservedAtoms
      */
-    public TrainingMap(Database targetDatabase, Database truthDatabase) {
-        labelMap = new HashMap<RandomVariableAtom, ObservedAtom>(targetDatabase.getAtomStore().size());
+    public TrainingMap(AtomStore targetAtomStore, AtomStore truthAtomStore) {
+        labelMap = new HashMap<RandomVariableAtom, ObservedAtom>(targetAtomStore.size());
         observedMap = new HashMap<ObservedAtom, ObservedAtom>();
         latentVariables = new ArrayList<RandomVariableAtom>();
         missingLabels = new ArrayList<ObservedAtom>();
         missingTargets = new ArrayList<ObservedAtom>();

         Set<GroundAtom> seenTruthAtoms = new HashSet<GroundAtom>();

-        for (GroundAtom targetAtom : targetDatabase.getAtomStore()) {
+        for (GroundAtom targetAtom : targetAtomStore) {
             if (targetAtom.getPredicate() instanceof FunctionalPredicate) {
                 continue;
             }

             // Note that hasAtom() will not return true for an unmanaged atom (except pre-cached functional predicates).
             GroundAtom truthAtom = null;
-            if (truthDatabase.getAtomStore().hasAtom(targetAtom.getPredicate(), targetAtom.getArguments())) {
-                truthAtom = truthDatabase.getAtomStore().getAtom(targetAtom.getPredicate(), targetAtom.getArguments());
+            if (truthAtomStore.hasAtom(targetAtom)) {
+                truthAtom = truthAtomStore.getAtom(truthAtomStore.getAtomIndex(targetAtom));
             }

             // Skip any truth atom that is not observed.
@@ -129,12 +129,12 @@ public TrainingMap(Database targetDatabase, Database truthDatabase) {
             }
         }

-        for (GroundAtom truthAtom : truthDatabase.getAtomStore()) {
+        for (GroundAtom truthAtom : truthAtomStore) {
             if (!(truthAtom instanceof ObservedAtom) || seenTruthAtoms.contains(truthAtom)) {
                 continue;
             }

-            boolean hasAtom = targetDatabase.getAtomStore().hasAtom(truthAtom.getPredicate(), truthAtom.getArguments());
+            boolean hasAtom = targetAtomStore.hasAtom(truthAtom);
             if (hasAtom) {
                 // This shouldn't be possible (since we already iterated through the target atoms).
                 // This means that the target is not cached.
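The constructor also switches from predicate/argument lookups to whole-atom lookups against the store's index. A hypothetical helper capturing the new idiom — lookupMatching is illustrative and not part of this commit; it assumes only the hasAtom(GroundAtom), getAtomIndex(GroundAtom), and getAtom(int) calls used above:

    // Resolve the atom in `store` that matches `atom` on predicate and arguments,
    // or return null when the store holds no such atom.
    private static GroundAtom lookupMatching(AtomStore store, GroundAtom atom) {
        if (!store.hasAtom(atom)) {
            return null;
        }
        return store.getAtom(store.getAtomIndex(atom));
    }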

@@ -178,8 +178,8 @@ private void initGroundModel(InferenceApplication trainInferenceApplication, Inf
             return;
         }

-        TrainingMap trainingMap = new TrainingMap(trainInferenceApplication.getDatabase(), trainTruthDatabase);
-        TrainingMap validationMap = new TrainingMap(validationInferenceApplication.getDatabase(), validationTruthDatabase);
+        TrainingMap trainingMap = new TrainingMap(trainInferenceApplication.getDatabase().getAtomStore(), trainTruthDatabase.getAtomStore());
+        TrainingMap validationMap = new TrainingMap(validationInferenceApplication.getDatabase().getAtomStore(), validationTruthDatabase.getAtomStore());

         initGroundModel(trainInferenceApplication, trainingMap, validationInferenceApplication, validationMap);
     }

@@ -18,6 +18,7 @@
 package org.linqs.psl.application.learning.weight.gradient;

 import org.linqs.psl.application.inference.InferenceApplication;
+import org.linqs.psl.application.learning.weight.TrainingMap;
 import org.linqs.psl.application.learning.weight.WeightLearningApplication;
 import org.linqs.psl.application.learning.weight.gradient.batchgenerator.BatchGenerator;
 import org.linqs.psl.config.Options;
@@ -72,11 +73,12 @@ public static enum GDExtension {
     protected int trainingEvaluationComputePeriod;
     protected SimpleTermStore<? extends ReasonerTerm> trainFullTermStore;
+    protected TrainingMap fullTrainingMap;
     protected List<DeepModelPredicate> trainFullDeepModelPredicates;
     protected TermState[] trainFullMAPTermState;
     protected float[] trainFullMAPAtomValueState;
-    double currentTrainingEvaluationMetric;
-    double bestTrainingEvaluationMetric;
+    protected double currentTrainingEvaluationMetric;
+    protected double bestTrainingEvaluationMetric;
     protected boolean fullMAPEvaluationBreak;
     protected int fullMAPEvaluationPatience;
     protected int lastTrainingImprovementEpoch;
@@ -136,6 +138,7 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
         trainingEvaluationComputePeriod = Options.WLA_GRADIENT_DESCENT_TRAINING_COMPUTE_PERIOD.getInt();
         trainFullTermStore = null;
+        fullTrainingMap = null;
         trainFullDeepModelPredicates = null;
         trainFullMAPTermState = null;
         trainFullMAPAtomValueState = null;
@@ -217,7 +220,9 @@ protected void validateState() {
     }

     protected void initializeFullModels() {
-        this.trainFullTermStore = (SimpleTermStore<? extends ReasonerTerm>)trainInferenceApplication.getTermStore();
+        trainFullTermStore = (SimpleTermStore<? extends ReasonerTerm>)trainInferenceApplication.getTermStore();
+
+        fullTrainingMap = trainingMap;

         trainFullDeepModelPredicates = deepModelPredicates;
@@ -234,7 +239,7 @@ protected void initializeBatches() {
         }

         batchGenerator = BatchGenerator.getBatchGenerator(Options.WLA_GRADIENT_DESCENT_BATCH_GENERATOR.getString(),
-                trainInferenceApplication, trainFullTermStore, deepPredicates);
+                trainInferenceApplication, trainFullTermStore, deepPredicates, trainTruthDatabase.getAtomStore());
         batchGenerator.generateBatches();
     }
@@ -471,7 +476,7 @@ protected void measureEpochParameterMovement() {

     protected void setFullModel() {
         trainInferenceApplication.setTermStore(trainFullTermStore);
-
+        trainingMap = fullTrainingMap;
         trainMAPTermState = trainFullMAPTermState;
         trainMAPAtomValueState = trainFullMAPAtomValueState;
@@ -488,6 +493,7 @@ protected void setBatch(int batch) {
         List<DeepModelPredicate> batchDeepModelPredicates = batchGenerator.getBatchDeepModelPredicates(batch);

         trainInferenceApplication.setTermStore(batchTermStore);
+        trainingMap = batchGenerator.getBatchTrainingMap(batch);
         trainMAPTermState = batchMAPTermStates.get(batch);
         trainMAPAtomValueState = batchMAPAtomValueStates.get(batch);
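With both hooks in place, trainingMap always tracks the active term store: setBatch() installs the batch's pre-built map and setFullModel() restores the full one. An illustrative epoch shape under that assumption — the loop scaffolding and the visibility of numBatchTermStores() are not established by this commit:

    // Illustrative only: every term-store swap also swaps the training map.
    for (int batch = 0; batch < batchGenerator.numBatchTermStores(); batch++) {
        setBatch(batch);   // batch term store + batchGenerator.getBatchTrainingMap(batch)
        // ... compute gradients against the batch training map ...
    }
    setFullModel();        // trainFullTermStore + fullTrainingMap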

@@ -18,6 +18,8 @@
 package org.linqs.psl.application.learning.weight.gradient.batchgenerator;

 import org.linqs.psl.application.inference.InferenceApplication;
+import org.linqs.psl.application.learning.weight.TrainingMap;
+import org.linqs.psl.database.AtomStore;
 import org.linqs.psl.model.deep.DeepModelPredicate;
 import org.linqs.psl.model.predicate.DeepPredicate;
 import org.linqs.psl.reasoner.term.ReasonerTerm;
@@ -37,21 +39,28 @@
 public abstract class BatchGenerator {
     protected InferenceApplication inferenceApplication;
     protected SimpleTermStore<? extends ReasonerTerm> fullTermStore;
+    protected AtomStore fullTruthAtomStore;
     protected List<DeepPredicate> deepPredicates;

     protected List<SimpleTermStore<? extends ReasonerTerm>> batchTermStores;
+    protected List<AtomStore> batchTruthAtomStores;
+    protected List<TrainingMap> batchTrainingMaps;
     protected List<List<DeepModelPredicate>> batchDeepModelPredicates;

     protected ArrayList<Integer> batchPermutation;
     protected int currentBatchPermutationIndex;


-    public BatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore, List<DeepPredicate> deepPredicates) {
+    public BatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore,
+            List<DeepPredicate> deepPredicates, AtomStore fullTruthAtomStore) {
         this.inferenceApplication = inferenceApplication;
         this.fullTermStore = fullTermStore;
+        this.fullTruthAtomStore = fullTruthAtomStore;
         this.deepPredicates = deepPredicates;

         batchTermStores = new ArrayList<SimpleTermStore<? extends ReasonerTerm>>();
+        batchTruthAtomStores = new ArrayList<AtomStore>();
+        batchTrainingMaps = new ArrayList<TrainingMap>();
         batchDeepModelPredicates = new ArrayList<List<DeepModelPredicate>>();
         batchPermutation = new ArrayList<Integer>();
@@ -85,11 +94,33 @@ public List<DeepModelPredicate> getBatchDeepModelPredicates(int index) {
         return batchDeepModelPredicates.get(index);
     }

+    public List<AtomStore> getBatchTruthAtomStores() {
+        return batchTruthAtomStores;
+    }
+
+    public AtomStore getBatchTruthAtomStore(int index) {
+        return batchTruthAtomStores.get(index);
+    }
+
+    public List<TrainingMap> getBatchTrainingMaps() {
+        return batchTrainingMaps;
+    }
+
+    public TrainingMap getBatchTrainingMap(int index) {
+        return batchTrainingMaps.get(index);
+    }
+
     public void generateBatches() {
         clear();

-        generateBatchTermStores();
+        generateBatchesInternal();
+
+        // Generate batch training maps.
+        for (int i = 0; i < numBatchTermStores(); i++) {
+            batchTrainingMaps.add(new TrainingMap(batchTermStores.get(i).getAtomStore(), batchTruthAtomStores.get(i)));
+        }

+        // Generate batch deep model predicates.
         for (int i = 0; i < numBatchTermStores(); i++) {
             SimpleTermStore<? extends ReasonerTerm> batchTermStore = batchTermStores.get(i);
             batchDeepModelPredicates.add(new ArrayList<DeepModelPredicate>());
@@ -101,12 +132,13 @@ public void generateBatches() {
             }
         }

+        // Generate initial batch permutation.
         for (int i = 0; i < numBatchTermStores(); i++) {
             batchPermutation.add(i);
         }
     }

-    public abstract void generateBatchTermStores();
+    protected abstract void generateBatchesInternal();

     /**
      * Permute the order of the batches.
@@ -184,7 +216,7 @@ public void close() {
      */
     public static BatchGenerator getBatchGenerator(String name, InferenceApplication inferenceApplication,
             SimpleTermStore<? extends ReasonerTerm> fullTermStore,
-            List<DeepPredicate> deepPredicates) {
+            List<DeepPredicate> deepPredicates, AtomStore fullTruthAtomStore) {
         String className = Reflection.resolveClassName(name);
         if (className == null) {
             throw new IllegalArgumentException("Could not find class: " + name);
@@ -201,14 +233,14 @@ public static BatchGenerator getBatchGenerator(String name, InferenceApplication
         Constructor<? extends BatchGenerator> constructor = null;
         try {
-            constructor = classObject.getConstructor(InferenceApplication.class, SimpleTermStore.class, List.class);
+            constructor = classObject.getConstructor(InferenceApplication.class, SimpleTermStore.class, List.class, AtomStore.class);
         } catch (NoSuchMethodException ex) {
             throw new IllegalArgumentException("No suitable constructor found for batch generator: " + className + ".", ex);
         }

         BatchGenerator batchGenerator = null;
         try {
-            batchGenerator = constructor.newInstance(inferenceApplication, fullTermStore, deepPredicates);
+            batchGenerator = constructor.newInstance(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore);
         } catch (InstantiationException ex) {
             throw new RuntimeException("Unable to instantiate weight learner (" + className + ")", ex);
         } catch (IllegalAccessException ex) {
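A sketch of the updated factory contract from a caller's perspective, using only the calls visible in this diff (the variable names are illustrative):

    BatchGenerator batchGenerator = BatchGenerator.getBatchGenerator(
            Options.WLA_GRADIENT_DESCENT_BATCH_GENERATOR.getString(),
            trainInferenceApplication, trainFullTermStore, deepPredicates,
            trainTruthDatabase.getAtomStore());
    batchGenerator.generateBatches();

    // Every batch index now carries a parallel truth atom store and training map.
    AtomStore firstBatchTruth = batchGenerator.getBatchTruthAtomStore(0);
    TrainingMap firstBatchMap = batchGenerator.getBatchTrainingMap(0);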

@@ -35,17 +35,19 @@ public class ConnectedComponentBatchGenerator extends BatchGenerator {
     private final int batchSize;

-    public ConnectedComponentBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore, List<DeepPredicate> deepPredicates) {
-        super(inferenceApplication, fullTermStore, deepPredicates);
+    public ConnectedComponentBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore,
+            List<DeepPredicate> deepPredicates, AtomStore fullTruthAtomStore) {
+        super(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore);

         batchSize = Options.WLA_CONNECTED_COMPONENT_BATCH_SIZE.getInt();
     }

     @Override
-    public void generateBatchTermStores() {
+    public void generateBatchesInternal() {
         AtomStore fullAtomStore = fullTermStore.getAtomStore();

         AtomStore batchAtomStore = new AtomStore();
+        batchTruthAtomStores.add(new AtomStore());
         SimpleTermStore<? extends ReasonerTerm> batchTermStore = (SimpleTermStore<? extends ReasonerTerm>) inferenceApplication.createTermStore();
         batchTermStore.setAtomStore(batchAtomStore);
@@ -58,6 +60,7 @@
             batchNumComponents = 0;

             batchAtomStore = new AtomStore();
+            batchTruthAtomStores.add(new AtomStore());
             batchTermStore = (SimpleTermStore<? extends ReasonerTerm>) inferenceApplication.createTermStore();
             batchTermStore.setAtomStore(batchAtomStore);
         }
@@ -71,6 +74,11 @@
             GroundAtom atom = fullAtomStore.getAtom(originalAtomIndexes[i]);
             if (!batchAtomStore.hasAtom(atom)) {
                 batchAtomStore.addAtom(atom.copy());
+
+                // Add the atom to the truth atom store if it has a truth atom.
+                if (fullTruthAtomStore.hasAtom(atom)) {
+                    batchTruthAtomStores.get(batchTruthAtomStores.size() - 1).addAtom(fullTruthAtomStore.getAtom(fullTruthAtomStore.getAtomIndex(atom)));
+                }
             }

             newAtomIndexes[i] = batchAtomStore.getAtomIndex(atom);
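The mirroring step above keeps each batch's truth store aligned with its target store: a truth atom is copied in exactly when its target atom first joins the batch, so the per-batch TrainingMap built later in generateBatches() pairs them one-to-one. Condensed, with currentTruthStore as an illustrative local alias not present in the commit:

    // The newest truth store belongs to the batch currently being filled.
    AtomStore currentTruthStore = batchTruthAtomStores.get(batchTruthAtomStores.size() - 1);
    if (fullTruthAtomStore.hasAtom(atom)) {
        currentTruthStore.addAtom(fullTruthAtomStore.getAtom(fullTruthAtomStore.getAtomIndex(atom)));
    }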

@@ -18,6 +18,7 @@
 package org.linqs.psl.application.learning.weight.gradient.batchgenerator;

 import org.linqs.psl.application.inference.InferenceApplication;
+import org.linqs.psl.database.AtomStore;
 import org.linqs.psl.model.predicate.DeepPredicate;
 import org.linqs.psl.reasoner.term.ReasonerTerm;
 import org.linqs.psl.reasoner.term.SimpleTermStore;
@@ -30,12 +31,14 @@
  */
 public class FullBatchGenerator extends BatchGenerator {

-    public FullBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore, List<DeepPredicate> deepPredicates) {
-        super(inferenceApplication, fullTermStore, deepPredicates);
+    public FullBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore,
+            List<DeepPredicate> deepPredicates, AtomStore fullTruthAtomStore) {
+        super(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore);
     }

     @Override
-    public void generateBatchTermStores() {
+    public void generateBatchesInternal() {
         batchTermStores.add(fullTermStore.copy());
+        batchTruthAtomStores.add(fullTruthAtomStore.copy());
     }
 }

@@ -18,6 +18,7 @@
 package org.linqs.psl.application.learning.weight.gradient.batchgenerator;

 import org.linqs.psl.application.inference.InferenceApplication;
+import org.linqs.psl.database.AtomStore;
 import org.linqs.psl.model.predicate.DeepPredicate;
 import org.linqs.psl.reasoner.term.ReasonerTerm;
 import org.linqs.psl.reasoner.term.SimpleTermStore;
@@ -34,8 +35,9 @@ public class NeuralBatchGenerator extends BatchGenerator {
     int batchCount;
     int numBatches;

-    public NeuralBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore, List<DeepPredicate> deepPredicates) {
-        super(inferenceApplication, fullTermStore, deepPredicates);
+    public NeuralBatchGenerator(InferenceApplication inferenceApplication, SimpleTermStore<? extends ReasonerTerm> fullTermStore,
+            List<DeepPredicate> deepPredicates, AtomStore fullTruthAtomStore) {
+        super(inferenceApplication, fullTermStore, deepPredicates, fullTruthAtomStore);

         assert deepPredicates.size() >= 1;
@@ -59,8 +61,9 @@ public int epochStart() {
     }

     @Override
-    public void generateBatchTermStores() {
+    public void generateBatchesInternal() {
         batchTermStores.add(fullTermStore.copy());
+        batchTruthAtomStores.add(fullTruthAtomStore.copy());
     }

     @Override

@@ -131,7 +131,7 @@ public double getNormalizedMaxRepMetric() {
      * A convenience call for those who don't want to create a training map directly.
      */
     public void compute(Database rvDB, Database truthDB, StandardPredicate predicate) {
-        TrainingMap map = new TrainingMap(rvDB, truthDB);
+        TrainingMap map = new TrainingMap(rvDB.getAtomStore(), truthDB.getAtomStore());
         compute(map, predicate);
     }

@@ -109,7 +109,7 @@ public void setUp() {
         targetsDatabase = dataStore.getDatabase(targetOpenPartition, targetClosedPartition);
         truthDatabase = dataStore.getDatabase(truthOpenPartition, truthClosedPartition);

-        trainingMap = new TrainingMap(targetsDatabase, truthDatabase);
+        trainingMap = new TrainingMap(targetsDatabase.getAtomStore(), truthDatabase.getAtomStore());
     }

     @After

@@ -76,7 +76,7 @@ public void setUp() {
         Database results = dataStore.getDatabase(targetPartition);
         Database truth = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates());

-        trainingMap = new TrainingMap(results, truth);
+        trainingMap = new TrainingMap(results.getAtomStore(), truth.getAtomStore());

         // Since we only need the map, we can close all the databases.
         results.close();