Evaluation policy gradient.
dickensc committed Jan 29, 2024
1 parent e22079b commit c1e95da
Showing 14 changed files with 163 additions and 376 deletions.
@@ -173,9 +173,7 @@ public double inference(boolean commitAtoms, boolean reset, List<EvaluationInsta
trainingMap = new TrainingMap(database.getAtomStore(), truthDatabase.getAtomStore());
}

log.info("Beginning inference.");
double objective = internalInference(evaluations, trainingMap);
log.info("Inference complete.");
atomsCommitted = false;

// Commits the RandomVariableAtoms back to the Database.
@@ -211,7 +211,8 @@ public void initGroundModel(InferenceApplication trainInferenceApplication, Trai
deepModelPredicates.add(((DeepPredicate)predicate).getDeepModel());

DeepModelPredicate validationDeepModelPredicate = ((DeepPredicate)predicate).getDeepModel().copy();
validationDeepModelPredicate.setAtomStore(validationInferenceApplication.getDatabase().getAtomStore(), true);
log.trace("Initializing Validation Deep Model Predicate: " + validationDeepModelPredicate.toString());
validationDeepModelPredicate.setAtomStore(validationInferenceApplication.getTermStore().getAtomStore(), true);
validationDeepModelPredicates.add(validationDeepModelPredicate);
}
}
@@ -528,12 +528,11 @@ protected void setBatch(int batch) {

protected void setValidationModel() {
// Set to validation deep model predicates.
// Note predict is not called here and should be called after the batch is set.
for (int i = 0; i < deepPredicates.size(); i++) {
DeepPredicate deepPredicate = deepPredicates.get(i);
deepPredicate.setDeepModel(validationDeepModelPredicates.get(i));
}

DeepPredicate.predictAllDeepPredicates();
}

protected void runTrainingEvaluation(int epoch) {
@@ -577,6 +576,7 @@ protected void runTrainingEvaluation(int epoch) {

protected void runValidationEvaluation(int epoch) {
setValidationModel();
DeepPredicate.predictAllDeepPredicates();

log.trace("Running Validation Inference.");
computeMAPStateWithWarmStart(validationInferenceApplication, validationMAPTermState, validationMAPAtomValueState);
@@ -38,14 +38,14 @@ public enum PolicyUpdate {
}

public enum RewardFunction {
NEGATIVE_LOSS,
NEGATIVE_LOSS_SQUARED,
INVERSE_LOSS
EVALUATION
}

private final DeepAtomPolicyDistribution deepAtomPolicyDistribution;
private final PolicyUpdate policyUpdate;
private final RewardFunction rewardFunction;
protected final RewardFunction rewardFunction;
private float valueFunction;
private float[] actionValueFunction;

private int numSamples;
protected int[] actionSampleCounts;
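
A note on the RewardFunction change above (my reading of the diff, not part of the commit): the removed NEGATIVE_LOSS, NEGATIVE_LOSS_SQUARED, and INVERSE_LOSS variants all turned the supervised loss L into a reward, roughly

    reward_negative_loss         = 1 - L
    reward_negative_loss_squared = (1 - L)^2
    reward_inverse_loss          = 1 / (L + epsilon)

whereas the single EVALUATION variant uses the evaluation metric itself as the reward (higher is better), so no loss-to-reward transform is needed; correspondingly, computeSupervisedLoss() is replaced by computeReward() further down.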
@@ -60,11 +60,11 @@ public enum RewardFunction {
protected List<float[]> batchLatentInferenceAtomValueStates;
protected float[] rvLatentEnergyGradient;
protected float[] deepLatentEnergyGradient;
protected float[] deepSupervisedLossGradient;
protected float[] deepPolicyGradient;

protected float energyLossCoefficient;

protected float MAPStateSupervisedLoss;
protected float MAPStateEvaluation;

protected float mapEnergy;
protected float[] mapIncompatibility;
@@ -76,6 +76,8 @@ public PolicyGradient(List<Rule> rules, Database trainTargetDatabase, Database t
deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase());
policyUpdate = PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase());
rewardFunction = RewardFunction.valueOf(Options.POLICY_GRADIENT_REWARD_FUNCTION.getString().toUpperCase());
valueFunction = 0.0f;
actionValueFunction = null;

numSamples = Options.POLICY_GRADIENT_NUM_SAMPLES.getInt();
actionSampleCounts = null;
@@ -89,11 +91,11 @@ public PolicyGradient(List<Rule> rules, Database trainTargetDatabase, Database t
batchLatentInferenceAtomValueStates = new ArrayList<float[]>();
rvLatentEnergyGradient = null;
deepLatentEnergyGradient = null;
deepSupervisedLossGradient = null;
deepPolicyGradient = null;

energyLossCoefficient = Options.POLICY_GRADIENT_ENERGY_LOSS_COEFFICIENT.getFloat();

MAPStateSupervisedLoss = Float.POSITIVE_INFINITY;
MAPStateEvaluation = Float.NEGATIVE_INFINITY;

mapEnergy = Float.POSITIVE_INFINITY;
mapIncompatibility = new float[mutableRules.size()];
@@ -108,11 +110,11 @@ protected void initForLearning() {
}
}

protected abstract float computeSupervisedLoss();
protected abstract float computeReward();

@Override
protected float computeLearningLoss() {
return MAPStateSupervisedLoss + energyLossCoefficient * latentInferenceEnergy;
return (float) ((evaluation.getNormalizedMaxRepMetric() - MAPStateEvaluation) + energyLossCoefficient * latentInferenceEnergy);
}

@Override
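
A hedged gloss on the new computeLearningLoss (the symbol names are mine, not from the code): the reported learning loss becomes roughly

    loss = (Eval_max - Eval_MAP) + c * E_latent

where Eval_max is evaluation.getNormalizedMaxRepMetric() (presumably the best attainable value of the metric), Eval_MAP is MAPStateEvaluation, c is energyLossCoefficient, and E_latent is latentInferenceEnergy. Because the reward is now an evaluation metric to be maximized, the non-negative gap (Eval_max - Eval_MAP) takes the place the supervised loss held in the old expression.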
@@ -132,8 +134,8 @@ protected void initializeGradients() {

rvLatentEnergyGradient = new float[trainFullMAPAtomValueState.length];
deepLatentEnergyGradient = new float[trainFullMAPAtomValueState.length];
deepSupervisedLossGradient = new float[trainFullMAPAtomValueState.length];

deepPolicyGradient = new float[trainFullMAPAtomValueState.length];
actionValueFunction = new float[trainFullMAPAtomValueState.length];
actionSampleCounts = new int[trainFullMAPAtomValueState.length];
}

@@ -143,8 +145,8 @@ protected void resetGradients() {

Arrays.fill(rvLatentEnergyGradient, 0.0f);
Arrays.fill(deepLatentEnergyGradient, 0.0f);
Arrays.fill(deepSupervisedLossGradient, 0.0f);

Arrays.fill(deepPolicyGradient, 0.0f);
Arrays.fill(actionValueFunction, 0.0f);
Arrays.fill(actionSampleCounts, 0);
}

@@ -164,65 +166,90 @@ protected void computeIterationStatistics() {

computeMAPInferenceStatistics();

MAPStateSupervisedLoss = computeSupervisedLoss();
MAPStateEvaluation = computeReward();

computeLatentInferenceStatistics();

// Save the initial deep model predictions to reset the deep atom values after computing iteration statistics
// and to compute action probabilities.
System.arraycopy(atomStore.getAtomValues(), 0, initialDeepAtomValues, 0, atomStore.size());

switch (policyUpdate) {
case REINFORCE:
addREINFORCESupervisedLossGradient(0.0f);
break;
case REINFORCE_BASELINE:
addREINFORCESupervisedLossGradient(MAPStateSupervisedLoss);
break;
default:
throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
}
computeValueFunctionEstimates();
computePolicyGradient();
}

private void addREINFORCESupervisedLossGradient(float baseline) {
for (int i = 0; i < numSamples; i++) {
sampleAllDeepAtomValues();
private void computePolicyGradient() {
AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) {
GroundAtom atom = atomStore.getAtom(atomIndex);

computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);
// Skip atoms that are not DeepAtoms.
if (!((atom instanceof RandomVariableAtom) && (atom.getPredicate() instanceof DeepPredicate))) {
continue;
}

float supervisedLoss = computeSupervisedLoss();
float reward = 0.0f;
if (actionSampleCounts[atomIndex] == 0) {
deepPolicyGradient[atomIndex] = 0.0f;
continue;
}

switch (rewardFunction) {
case NEGATIVE_LOSS:
reward = 1.0f - supervisedLoss;
switch (policyUpdate) {
case REINFORCE:
deepPolicyGradient[atomIndex] = -1.0f * (actionValueFunction[atomIndex]) / atom.getValue();
break;
case NEGATIVE_LOSS_SQUARED:
reward = (float) Math.pow(1.0f - supervisedLoss, 2.0f);
break;
case INVERSE_LOSS:
// The inverse loss may result in a reward of infinity.
// Therefore, we add a small constant to the loss to avoid division by zero.
reward = 1.0f / (supervisedLoss + MathUtils.EPSILON_FLOAT);
case REINFORCE_BASELINE:
deepPolicyGradient[atomIndex] = -1.0f * (actionValueFunction[atomIndex] - valueFunction) / atom.getValue();
break;
default:
throw new IllegalArgumentException("Unknown reward function: " + rewardFunction);
throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
}
}

clipPolicyGradient();
}

/**
* Clip policy gradient to stabilize learning.
*/
private void clipPolicyGradient() {
AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

float gradientMagnitude = MathUtils.pNorm(deepPolicyGradient, maxGradientNorm);

if (gradientMagnitude > maxGradientMagnitude) {
// log.trace("Clipping policy gradient. Original gradient magnitude: {} exceeds limit: {} in L_{} space.",
// gradientMagnitude, maxGradientMagnitude, maxGradientNorm);
for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) {
deepPolicyGradient[atomIndex] = maxGradientMagnitude * deepPolicyGradient[atomIndex] / gradientMagnitude;
}
}
}

private void computeValueFunctionEstimates() {
valueFunction = 0.0f;

addPolicyScoreGradient(reward - baseline);
for (int i = 0; i < numSamples; i++) {
sampleAllDeepAtomValues();

computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);

float reward = computeReward();
addActionValue(reward);
valueFunction += reward;

resetAllDeepAtomValues();
}

for (int i = 0; i < deepSupervisedLossGradient.length; i++) {
// Average the value functions.
valueFunction /= numSamples;

for (int i = 0; i < actionValueFunction.length; i++) {
if (actionSampleCounts[i] == 0) {
deepSupervisedLossGradient[i] = 0.0f;
actionValueFunction[i] = 0.0f;
continue;
}

// log.trace("Atom: {} Deep Supervised Loss Gradient: {}",
// trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), deepSupervisedLossGradient[i] / actionSampleCounts[i]);
deepSupervisedLossGradient[i] /= actionSampleCounts[i];
actionValueFunction[i] /= actionSampleCounts[i];
}
}

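For readers tracing the estimator (a hedged summary of the code above, in my own notation): computeValueFunctionEstimates draws numSamples action samples, runs warm-started MAP inference for each, and averages the rewards into a global baseline and per-atom action values, which computePolicyGradient then turns into a score-function (REINFORCE) gradient:

    V                    = (1/N) * sum_i r_i                (baseline over all N samples)
    Q(a)                 = mean of r_i over samples where atom a was selected
    REINFORCE:           g_a = -Q(a) / p(a)
    REINFORCE_BASELINE:  g_a = -(Q(a) - V) / p(a)

Here p(a) = atom.getValue() is the deep model's probability for the selected category, and 1/p(a) is the derivative of log p(a) with respect to p(a), so g_a is the usual -(reward - baseline) * d log p(a) / d p(a) term written as a loss gradient. clipPolicyGradient then rescales the whole gradient vector whenever its L_maxGradientNorm norm exceeds maxGradientMagnitude.
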
@@ -338,17 +365,11 @@ protected void addLearningLossWeightGradient() {
protected void addTotalAtomGradient() {
for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
rvGradient[i] = energyLossCoefficient * rvLatentEnergyGradient[i];
deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepSupervisedLossGradient[i];

if (trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).getPredicate() instanceof DeepPredicate) {
log.trace("Atom: {} deepLatentEnergyGradient: {}, deepSupervisedLossGradient: {}",
trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(),
deepLatentEnergyGradient[i], deepSupervisedLossGradient[i]);
}
deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepPolicyGradient[i];
}
}

private void addPolicyScoreGradient(float reward) {
private void addActionValue(float reward) {
AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) {
GroundAtom atom = atomStore.getAtom(atomIndex);
Expand All @@ -360,25 +381,24 @@ private void addPolicyScoreGradient(float reward) {

switch (deepAtomPolicyDistribution) {
case CATEGORICAL:
addCategoricalPolicyScoreGradient(atomIndex, (RandomVariableAtom) atom, reward);
addCategoricalActionValue(atomIndex, (RandomVariableAtom) atom, reward);
break;
default:
throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution);
}
}
}

private void addCategoricalPolicyScoreGradient(int atomIndex, RandomVariableAtom atom, float reward) {
private void addCategoricalActionValue(int atomIndex, RandomVariableAtom atom, float reward) {
// Skip atoms not selected by the policy.
if (atom.getValue() == 0.0f) {
if (MathUtils.isZero(atom.getValue())) {
return;
}

switch (policyUpdate) {
case REINFORCE:
case REINFORCE_BASELINE:
// The initialDeepAtomValues are the action probabilities.
deepSupervisedLossGradient[atomIndex] -= reward / initialDeepAtomValues[atomIndex];
actionValueFunction[atomIndex] += reward;
break;
default:
throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);

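Finally, a small self-contained Java sketch of the same REINFORCE-with-baseline pattern (purely illustrative: the class name, probabilities, and rewards are invented, and this is not PSL code). It samples categorical actions, averages rewards into per-action values and a global baseline, and forms the gradient -(Q(a) - V) / p(a), mirroring addActionValue, computeValueFunctionEstimates, and computePolicyGradient above.

import java.util.Random;

/**
 * Standalone toy illustration (not PSL code) of the REINFORCE-with-baseline
 * update pattern in the diff above: sample categorical actions, average
 * rewards into per-action value estimates and a global baseline, then form
 * the policy gradient -(Q(a) - V) / p(a) for each sampled action.
 */
public class ReinforceBaselineSketch {
    public static void main(String[] args) {
        Random rng = new Random(42);

        // Toy categorical policy over three actions (probabilities sum to 1).
        float[] probabilities = {0.2f, 0.5f, 0.3f};
        // Hypothetical per-action rewards the environment would return.
        float[] trueRewards = {0.1f, 0.9f, 0.4f};

        int numSamples = 1000;
        float[] actionValue = new float[probabilities.length];
        int[] sampleCounts = new int[probabilities.length];
        float baseline = 0.0f;

        // Monte Carlo estimates of Q(a) and the baseline V.
        for (int i = 0; i < numSamples; i++) {
            int action = sampleCategorical(probabilities, rng);
            float reward = trueRewards[action];

            actionValue[action] += reward;
            sampleCounts[action]++;
            baseline += reward;
        }
        baseline /= numSamples;

        // Policy gradient with respect to the action probabilities.
        float[] gradient = new float[probabilities.length];
        for (int a = 0; a < probabilities.length; a++) {
            if (sampleCounts[a] == 0) {
                continue;
            }
            actionValue[a] /= sampleCounts[a];
            gradient[a] = -1.0f * (actionValue[a] - baseline) / probabilities[a];
        }

        for (int a = 0; a < probabilities.length; a++) {
            System.out.printf("action %d: Q=%.3f V=%.3f gradient=%.3f%n",
                    a, actionValue[a], baseline, gradient[a]);
        }
    }

    private static int sampleCategorical(float[] probabilities, Random rng) {
        float draw = rng.nextFloat();
        float cumulative = 0.0f;
        for (int a = 0; a < probabilities.length; a++) {
            cumulative += probabilities[a];
            if (draw < cumulative) {
                return a;
            }
        }
        return probabilities.length - 1;
    }
}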
