From 513d50da3662081a0e20ea1222788e8a6718b17d Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Thu, 19 Dec 2024 06:11:07 +0000 Subject: [PATCH] Add a percentage correct metric for BC training (#402) This patch adds a percentage correct metric for BC training, which makes it a lot easier to interpret how a model is doing rather than just staring at loss values. --- compiler_opt/rl/trainer.py | 27 ++++++++++++++++++++++++++- compiler_opt/rl/trainer_test.py | 24 +++++++++++++++++++++--- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/compiler_opt/rl/trainer.py b/compiler_opt/rl/trainer.py index 10c6654c..94fa88b0 100644 --- a/compiler_opt/rl/trainer.py +++ b/compiler_opt/rl/trainer.py @@ -23,6 +23,7 @@ from compiler_opt.rl import random_net_distillation from tf_agents.agents import tf_agent from tf_agents.policies import policy_loader +from tf_agents import trajectories from tf_agents.utils import common as common_utils from typing import Optional @@ -54,7 +55,8 @@ def __init__( log_interval=100, summary_log_interval=100, summary_export_interval=1000, - summaries_flush_secs=10): + summaries_flush_secs=10, + bc_percentage_correct=False): """Initialize the Trainer object. Args: @@ -70,6 +72,9 @@ def __init__( summary_export_interval: int, the training step interval for exporting to tensorboard. summaries_flush_secs: int, the seconds for flushing to tensorboard. + bc_percentage_correct: bool, whether or not to log the accuracy of the + current batch. This is intended for use during BC training where labels + for the "correct" decision are available. """ self._root_dir = root_dir self._agent = agent @@ -84,6 +89,7 @@ def __init__( self._summary_writer.set_as_default() self._global_step = tf.compat.v1.train.get_or_create_global_step() + self._bc_percentage_correct = bc_percentage_correct # Initialize agent and trajectory replay. 
# Wrap training and trajectory replay in a tf.function to make it much @@ -118,6 +124,7 @@ def _initialize_metrics(self): self._data_action_mean = tf.keras.metrics.Mean() self._data_reward_mean = tf.keras.metrics.Mean() self._num_trajectories = tf.keras.metrics.Sum() + self._percentage_correct = tf.keras.metrics.Accuracy() def _update_metrics(self, experience, monitor_dict): """Updates metrics and exports to Tensorboard.""" @@ -130,6 +137,16 @@ def _update_metrics(self, experience, monitor_dict): experience.reward, sample_weight=is_action) self._num_trajectories.update_state(experience.is_first()) + # Compute the accuracy if we are BC training. + if self._bc_percentage_correct: + experience_time_step = trajectories.TimeStep(experience.step_type, + experience.reward, + experience.discount, + experience.observation) + policy_actions = self._agent.policy.action(experience_time_step) + self._percentage_correct.update_state(experience.action, + policy_actions.action) + # Check earlier rather than later if we should record summaries. # TF also checks it, but much later. 
Needed to avoid looping through # the dict so gave the if a bigger scope @@ -147,6 +164,11 @@ def _update_metrics(self, experience, monitor_dict): name='num_trajectories', data=self._num_trajectories.result(), step=self._global_step) + if self._bc_percentage_correct: + tf.summary.scalar( + name='percentage_correct', + data=self._percentage_correct.result(), + step=self._global_step) for name_scope, d in monitor_dict.items(): with tf.name_scope(name_scope + '/'): @@ -159,6 +181,7 @@ def _update_metrics(self, experience, monitor_dict): def _reset_metrics(self): """Reset num_trajectories.""" self._num_trajectories.reset_states() + self._percentage_correct.reset_state() def _log_experiment(self, loss): """Log training info.""" @@ -204,6 +227,8 @@ def train(self, dataset_iter, monitor_dict, num_iterations: int): loss = self._agent.train(experience) + self._percentage_correct.reset_state() + self._update_metrics(experience, monitor_dict) self._log_experiment(loss.loss) self._save_checkpoint() diff --git a/compiler_opt/rl/trainer_test.py b/compiler_opt/rl/trainer_test.py index d42f4c83..16044c65 100644 --- a/compiler_opt/rl/trainer_test.py +++ b/compiler_opt/rl/trainer_test.py @@ -18,7 +18,7 @@ import tensorflow as tf from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent -from tf_agents.networks import q_rnn_network +from tf_agents.networks import q_network from tf_agents.specs import tensor_spec from tf_agents.trajectories import time_step from tf_agents.trajectories import trajectory @@ -66,10 +66,9 @@ def setUp(self): minimum=0, maximum=1, name='inlining_decision') - self._network = q_rnn_network.QRnnNetwork( + self._network = q_network.QNetwork( input_tensor_spec=self._time_step_spec.observation, action_spec=self._action_spec, - lstm_size=(40,), preprocessing_layers={ 'callee_users': tf.keras.layers.Lambda(lambda x: x) }) @@ -154,6 +153,25 @@ def test_training_metrics(self): self.assertEqual(2, test_trainer._data_reward_mean.result().numpy()) 
self.assertEqual(90, test_trainer._num_trajectories.result().numpy()) + def test_training_metrics_bc(self): + test_agent = behavioral_cloning_agent.BehavioralCloningAgent( + self._time_step_spec, + self._action_spec, + self._network, + tf.compat.v1.train.AdamOptimizer(), + num_outer_dims=2) + test_trainer = trainer.Trainer( + root_dir=self.get_temp_dir(), + agent=test_agent, + summary_log_interval=1, + bc_percentage_correct=True) + + dataset_iter = _create_test_data(batch_size=3, sequence_length=3) + monitor_dict = {'default': {'test': 1}} + test_trainer.train(dataset_iter, monitor_dict, num_iterations=10) + + self.assertLess(0.1, test_trainer._percentage_correct.result().numpy()) + def test_inference(self): test_agent = behavioral_cloning_agent.BehavioralCloningAgent( self._time_step_spec,