From 59156d2a6f5bcaa4ebede20c1cd3cc5bd4691399 Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Wed, 10 Jul 2024 04:33:58 +0000 Subject: [PATCH] WIP --- compiler_opt/es/blackbox_learner.py | 61 +-------------- compiler_opt/es/es_trainer.py | 9 ++- compiler_opt/es/es_trainer_lib.py | 75 ++++++++++--------- .../es/gin_configs/blackbox_learner.gin | 21 ++++++ compiler_opt/es/gin_configs/regalloc.gin | 17 +++++ compiler_opt/rl/compilation_runner.py | 2 +- compiler_opt/rl/trace_data_collector.py | 24 +++--- 7 files changed, 104 insertions(+), 105 deletions(-) create mode 100644 compiler_opt/es/gin_configs/blackbox_learner.gin create mode 100644 compiler_opt/es/gin_configs/regalloc.gin diff --git a/compiler_opt/es/blackbox_learner.py b/compiler_opt/es/blackbox_learner.py index 289d3e57..211e4911 100644 --- a/compiler_opt/es/blackbox_learner.py +++ b/compiler_opt/es/blackbox_learner.py @@ -63,15 +63,6 @@ class BlackboxLearnerConfig: # 0 means all num_top_directions: int - # How many IR files to try a single perturbation on? - num_ir_repeats_within_worker: int - - # How many times should we reuse IR to test different policies? - num_ir_repeats_across_worker: int - - # How many IR files to sample from the test corpus at each iteration - num_exact_evals: int - # How many perturbations to attempt at each perturbation total_num_perturbations: int @@ -128,7 +119,6 @@ class BlackboxLearner: def __init__(self, blackbox_opt: blackbox_optimizers.BlackboxOptimizer, - sampler: corpus.Corpus, tf_policy_path: str, output_dir: str, policy_saver_fn: PolicySaverCallableType, @@ -141,7 +131,6 @@ def __init__(self, Args: blackbox_opt: the blackbox optimizer to use - train_sampler: corpus_sampler for training data. tf_policy_path: where to write the tf policy output_dir: the directory to write all outputs policy_saver_fn: function to save a policy to cns @@ -152,7 +141,6 @@ def __init__(self, deadline: the deadline in seconds for requests to the inlining server. """ self._blackbox_opt = blackbox_opt - self._sampler = sampler self._tf_policy_path = tf_policy_path self._output_dir = output_dir self._policy_saver_fn = policy_saver_fn @@ -162,10 +150,6 @@ def __init__(self, self._deadline = deadline self._seed = seed - # While we're waiting for the ES requests, we can - # collect samples for the next round of training. 
- self._samples = [] - self._summary_writer = tf.summary.create_file_writer(output_dir) def _get_perturbations(self) -> List[npt.NDArray[np.float32]]: @@ -248,20 +232,9 @@ def get_model_weights(self) -> npt.NDArray[np.float32]: def _get_results( self, pool: FixedWorkerPool, perturbations: List[bytes]) -> List[concurrent.futures.Future]: - if not self._samples: - for _ in range(self._config.total_num_perturbations): - sample = self._sampler.sample(self._config.num_ir_repeats_within_worker) - self._samples.append(sample) - # add copy of sample for antithetic perturbation pair - if self._config.est_type == ( - blackbox_optimizers.EstimatorType.ANTITHETIC): - self._samples.append(sample) - - compile_args = zip(perturbations, self._samples) - _, futures = buffered_scheduler.schedule_on_worker_pool( - action=lambda w, v: w.compile(v[0], v[1]), - jobs=compile_args, + action=lambda w, v: w.es_compile(params=self._model_weights + v), + jobs=perturbations, worker_pool=pool) not_done = futures @@ -273,27 +246,6 @@ def _get_results( return futures - def _get_policy_as_bytes(self, - perturbation: npt.NDArray[np.float32]) -> bytes: - sm = tf.saved_model.load(self._tf_policy_path) - # devectorize the perturbation - policy_utils.set_vectorized_parameters_for_policy(sm, perturbation) - - with tempfile.TemporaryDirectory() as tmpdir: - sm_dir = os.path.join(tmpdir, 'sm') - tf.saved_model.save(sm, sm_dir, signatures=sm.signatures) - src = os.path.join(self._tf_policy_path, policy_saver.OUTPUT_SIGNATURE) - dst = os.path.join(sm_dir, policy_saver.OUTPUT_SIGNATURE) - tf.io.gfile.copy(src, dst) - - # convert to tflite - tfl_dir = os.path.join(tmpdir, 'tfl') - policy_saver.convert_mlgo_model(sm_dir, tfl_dir) - - # create and return policy - policy_obj = policy_saver.Policy.from_filesystem(tfl_dir) - return policy_obj.policy - def run_step(self, pool: FixedWorkerPool) -> None: """Run a single step of blackbox learning. This does not instantaneously return due to several I/O @@ -308,14 +260,7 @@ def run_step(self, pool: FixedWorkerPool) -> None: p for p in initial_perturbations for p in (p, -p) ] - # convert to bytes for compile job - # TODO: current conversion is inefficient. 
- # consider doing this on the worker side - perturbations_as_bytes = [] - for perturbation in initial_perturbations: - perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation)) - - results = self._get_results(pool, perturbations_as_bytes) + results = self._get_results(pool, initial_perturbations) rewards = self._get_rewards(results) num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards) diff --git a/compiler_opt/es/es_trainer.py b/compiler_opt/es/es_trainer.py index c0e43807..a509f941 100644 --- a/compiler_opt/es/es_trainer.py +++ b/compiler_opt/es/es_trainer.py @@ -18,6 +18,7 @@ import gin from compiler_opt.es import es_trainer_lib +from compiler_opt.rl import registry _GIN_FILES = flags.DEFINE_multi_string( "gin_files", [], "List of paths to gin configuration files.") @@ -31,10 +32,12 @@ def main(_): _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False) logging.info(gin.config_str()) - final_weights = es_trainer_lib.train() + problem_config = registry.get_configuration() + final_weights = es_trainer_lib.train( + worker_class=problem_config.get_runner_type()) - logging.info("Final Weights:") - logging.info(", ".join(final_weights)) + logging.info("Training completed.") + # logging.info(", ".join(final_weights)) if __name__ == "__main__": diff --git a/compiler_opt/es/es_trainer_lib.py b/compiler_opt/es/es_trainer_lib.py index 2eaef72d..22ee20ce 100644 --- a/compiler_opt/es/es_trainer_lib.py +++ b/compiler_opt/es/es_trainer_lib.py @@ -14,18 +14,21 @@ # limitations under the License. """Local ES trainer.""" +import tempfile +from typing import Optional from absl import flags, logging import functools import gin import tensorflow as tf import os +from compiler_opt.distributed import worker from compiler_opt.distributed.local import local_worker_manager from compiler_opt.es import blackbox_optimizers from compiler_opt.es import gradient_ascent_optimization_algorithms from compiler_opt.es import blackbox_learner from compiler_opt.es import policy_utils -from compiler_opt.rl import policy_saver, corpus +from compiler_opt.rl import compilation_runner, policy_saver, trace_data_collector POLICY_NAME = "policy" @@ -42,15 +45,8 @@ "grad_reg_type", "ridge", "Regularization method to use with regression gradient.") _GRADIENT_ASCENT_OPTIMIZER_TYPE = flags.DEFINE_string( - "gradient_ascent_optimizer_type", None, + "gradient_ascent_optimizer_type", 'adam', "Gradient ascent optimization algorithm: 'momentum' or 'adam'") -flags.mark_flag_as_required("gradient_ascent_optimizer_type") -_GREEDY = flags.DEFINE_bool( - "greedy", - None, - "Whether to construct a greedy policy (argmax). \ - If False, a sampling-based policy will be used.", - required=True) _MOMENTUM = flags.DEFINE_float( "momentum", 0.0, "Momentum for momentum gradient ascent optimizer.") _OUTPUT_PATH = flags.DEFINE_string("output_path", "", @@ -59,28 +55,47 @@ "pretrained_policy_path", None, "The path of the pretrained policy. 
If not provided, it will \ construct a new policy with randomly initialized weights.") -_REQUEST_DEADLINE = flags.DEFINE_float( - "request_deadline", 30.0, "Deadline in seconds for requests \ - to the data collection requests.") -_TRAIN_CORPORA = flags.DEFINE_string("train_corpora", "", - "List of paths to training corpora") +_CORPUS_DIR = flags.DEFINE_string("corpus_dir", None, "The path to the corpus to use") + + +class ESWorker(worker.Worker): + + def __init__(self, *, all_gin): + gin.parse_config(all_gin) + policy = policy_utils.create_actor_policy() + saver = policy_saver.PolicySaver({POLICY_NAME: policy}) + self._template_dir = tempfile.mkdtemp() + saver.save(self._template_dir) + + def es_compile(self, params: list[float]) -> float: + with tempfile.TemporaryDirectory() as tempdir: + smdir = os.path.join(tempdir, 'sm') + my_model = tf.saved_model.load( + os.path.join(self._template_dir, POLICY_NAME)) + policy_utils.set_vectorized_parameters_for_policy(my_model, params) + tf.saved_model.save(my_model, smdir, signatures=my_model.signatures) + tflitedir = os.path.join(tempdir, 'tflite') + policy_saver.convert_saved_model( + smdir, os.path.join(tflitedir, policy_saver.TFLITE_MODEL_NAME)) + tf.io.gfile.copy( + os.path.join(self._template_dir, POLICY_NAME, + policy_saver.OUTPUT_SIGNATURE), + os.path.join(tflitedir, policy_saver.OUTPUT_SIGNATURE)) + + print(os.listdir(tflitedir)) + return 1.0 @gin.configurable -def train(additional_compilation_flags=(), - delete_compilation_flags=(), - worker_class=None): +def train(worker_class=None): """Train with ES.""" - if not _TRAIN_CORPORA.value: - raise ValueError("Need to supply nonempty train corpora.") - # Create directories if not tf.io.gfile.isdir(_OUTPUT_PATH.value): tf.io.gfile.makedirs(_OUTPUT_PATH.value) # Construct the policy and upload it - policy = policy_utils.create_actor_policy(greedy=_GREEDY.value) + policy = policy_utils.create_actor_policy() saver = policy_saver.PolicySaver({POLICY_NAME: policy}) # Save the policy @@ -112,17 +127,10 @@ def train(additional_compilation_flags=(), logging.info("Parameter dimension: %s", initial_parameters.shape) logging.info("Initial parameters: %s", initial_parameters) - cps = corpus.create_corpus_for_testing( - location=_TRAIN_CORPORA.value, - elements=[corpus.ModuleSpec(name="smth", size=1)], - additional_flags=additional_compilation_flags, - delete_flags=delete_compilation_flags) - # Construct policy saver - saved_policy = policy_utils.create_actor_policy(greedy=True) policy_saver_function = functools.partial( policy_utils.save_policy, - policy=saved_policy, + policy=policy, save_folder=os.path.join(_OUTPUT_PATH.value, "saved_policies")) # Get learner config @@ -203,14 +211,12 @@ def train(additional_compilation_flags=(), logging.info("Initializing blackbox learner.") learner = blackbox_learner.BlackboxLearner( blackbox_opt=blackbox_optimizer, - sampler=cps, tf_policy_path=os.path.join(policy_save_path, POLICY_NAME), output_dir=_OUTPUT_PATH.value, policy_saver_fn=policy_saver_function, model_weights=init_current_input, config=learner_config, - initial_step=init_iteration, - deadline=_REQUEST_DEADLINE.value) + initial_step=init_iteration) if not worker_class: logging.info("No Worker class selected. 
Stopping.") @@ -220,8 +226,9 @@ def train(additional_compilation_flags=(), learner_config.total_steps) with local_worker_manager.LocalWorkerPoolManager( - worker_class, learner_config.total_num_perturbations, arg="", - kwarg="") as pool: + ESWorker, + learner_config.total_num_perturbations, + all_gin=gin.config_str()) as pool: for _ in range(learner_config.total_steps): learner.run_step(pool) diff --git a/compiler_opt/es/gin_configs/blackbox_learner.gin b/compiler_opt/es/gin_configs/blackbox_learner.gin new file mode 100644 index 00000000..b91915f0 --- /dev/null +++ b/compiler_opt/es/gin_configs/blackbox_learner.gin @@ -0,0 +1,21 @@ +import compiler_opt.es.blackbox_learner +import compiler_opt.rl.gin_external_configurables +import compiler_opt.es.blackbox_optimizers + +# Blackbox learner config +BlackboxLearnerConfig.total_steps = 1 +BlackboxLearnerConfig.total_num_perturbations = 5 +BlackboxLearnerConfig.blackbox_optimizer = %blackbox_optimizers.Algorithm.MONTE_CARLO +BlackboxLearnerConfig.est_type = %blackbox_optimizers.EstimatorType.ANTITHETIC +# BlackboxLearnerConfig.est_type = %blackbox_optimizers.EstimatorType.FORWARD_FD +BlackboxLearnerConfig.fvalues_normalization = True +BlackboxLearnerConfig.hyperparameters_update_method = %blackbox_optimizers.UpdateMethod.NO_METHOD + +BlackboxLearnerConfig.num_top_directions = 0 + +# BlackboxLearnerConfig.precision_parameter = 0.025 +BlackboxLearnerConfig.precision_parameter = 0.5 + +# Try the 0.0005 step size next +BlackboxLearnerConfig.step_size = 0.005 +# BlackboxLearnerConfig.step_size = 0.0005 \ No newline at end of file diff --git a/compiler_opt/es/gin_configs/regalloc.gin b/compiler_opt/es/gin_configs/regalloc.gin new file mode 100644 index 00000000..f1acf227 --- /dev/null +++ b/compiler_opt/es/gin_configs/regalloc.gin @@ -0,0 +1,17 @@ +import compiler_opt.rl.gin_external_configurables +import compiler_opt.rl.regalloc.config +import compiler_opt.rl.regalloc.regalloc_network + +include 'compiler_opt/rl/regalloc/gin_configs/common.gin' + +regalloc.config.get_observation_processing_layer_creator.quantile_file_dir='compiler_opt/rl/regalloc/vocab' +regalloc.config.get_observation_processing_layer_creator.with_sqrt = False +regalloc.config.get_observation_processing_layer_creator.with_z_score_normalization = False + +RegAllocNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate() +RegAllocNetwork.fc_layer_params=(80, 40) +RegAllocNetwork.dropout_layer_params=None +RegAllocNetwork.activation_fn=@tf.keras.activations.relu + +policy_utils.create_actor_policy.actor_network_ctor = @regalloc_network.RegAllocNetwork +policy_utils.create_actor_policy.greedy = True diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index a3f5c0b4..0a7fd41f 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -208,7 +208,7 @@ def start_cancellable_process( command_env = os.environ.copy() # Disable tensorflow info messages during data collection if _QUIET.value: - command_env['TF_CPP_MIN_LOG_LEVEL'] = '1' + command_env['TF_CPP_MIN_LOG_LEVEL'] = '2' else: logging.info(shlex.join(cmdline)) with subprocess.Popen( diff --git a/compiler_opt/rl/trace_data_collector.py b/compiler_opt/rl/trace_data_collector.py index c2614a26..07e1023e 100644 --- a/compiler_opt/rl/trace_data_collector.py +++ b/compiler_opt/rl/trace_data_collector.py @@ -22,7 +22,7 @@ import shutil -def compile_corpus(corpus_path, output_path, clang_path): +def compile_corpus(corpus_path, output_path, clang_path, tflite_dir = 
None): with open( os.path.join(corpus_path, 'corpus_description.json'), encoding='utf-8') as corpus_description_handle: @@ -34,20 +34,26 @@ def compile_corpus(corpus_path, output_path, clang_path): module_full_output_path = os.path.join(output_path, module_path) + '.bc.o' pathlib.Path(os.path.dirname(module_full_output_path)).mkdir( parents=True, exist_ok=True) - + module_command_full_path = os.path.join(corpus_path, module_path) + '.cmd' with open(module_command_full_path) as module_command_handle: - module_command_line = tuple(module_command_handle.read().replace(r'{', r'{{').replace(r'}', - r'}}').split('\0')) + module_command_line = tuple(module_command_handle.read().replace( + r'{', r'{{').replace(r'}', r'}}').split('\0')) - command_vector = [ - clang_path] + command_vector = [clang_path] command_vector.extend(module_command_line) - command_vector.extend([module_full_input_path, '-o', module_full_output_path]) + command_vector.extend( + [module_full_input_path, '-o', module_full_output_path]) + + if tflite_dir is not None: + command_vector.extend(['-mllvm', '-regalloc-enable-advisor=development']) + command_vector.extend(['-mllvm', '-ml-regalloc-model=' + tflite_dir]) subprocess.run(command_vector) logging.info( f'Just finished compiling {module_full_output_path} ({module_index + 1}/{len(corpus_description["modules"])})' ) - - shutil.copy(os.path.join(corpus_path, 'corpus_description.json'), os.path.join(output_path, 'corpus_description.json')) + + shutil.copy( + os.path.join(corpus_path, 'corpus_description.json'), + os.path.join(output_path, 'corpus_description.json'))
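
Note on the es_compile stub: in this patch ESWorker.es_compile only materializes the perturbed policy as a TFLite model, prints the directory contents, and returns a constant 1.0, so the learner is currently training against a dummy reward. Below is a minimal sketch (not part of this change) of one way the worker could report a real reward. It assumes the worker is additionally handed a corpus directory and a clang path (for example from the _CORPUS_DIR flag added above, forwarded by the worker pool manager), and it assumes a size-based reward (the negated total size of the compiled objects). Both of those are illustrative choices, not something this patch implements.

# Sketch only: a possible es_compile that reports a real reward. The extra
# constructor arguments (corpus_dir, clang_path) and the size-based reward are
# assumptions for illustration; the pool manager would have to forward them
# alongside all_gin.
import os
import tempfile

import gin
import tensorflow as tf

from compiler_opt.distributed import worker
from compiler_opt.es import policy_utils
from compiler_opt.rl import policy_saver, trace_data_collector

POLICY_NAME = 'policy'  # mirrors POLICY_NAME in es_trainer_lib


class ESWorker(worker.Worker):

  def __init__(self, *, all_gin, corpus_dir, clang_path):
    gin.parse_config(all_gin)
    policy = policy_utils.create_actor_policy()
    saver = policy_saver.PolicySaver({POLICY_NAME: policy})
    self._template_dir = tempfile.mkdtemp()
    saver.save(self._template_dir)
    self._corpus_dir = corpus_dir  # assumed, e.g. _CORPUS_DIR.value
    self._clang_path = clang_path  # assumed path to a clang binary

  def es_compile(self, params: list[float]) -> float:
    with tempfile.TemporaryDirectory() as tempdir:
      # Materialize the perturbed policy as a TFLite model, as in the patch.
      smdir = os.path.join(tempdir, 'sm')
      my_model = tf.saved_model.load(
          os.path.join(self._template_dir, POLICY_NAME))
      policy_utils.set_vectorized_parameters_for_policy(my_model, params)
      tf.saved_model.save(my_model, smdir, signatures=my_model.signatures)
      tflitedir = os.path.join(tempdir, 'tflite')
      policy_saver.convert_saved_model(
          smdir, os.path.join(tflitedir, policy_saver.TFLITE_MODEL_NAME))
      tf.io.gfile.copy(
          os.path.join(self._template_dir, POLICY_NAME,
                       policy_saver.OUTPUT_SIGNATURE),
          os.path.join(tflitedir, policy_saver.OUTPUT_SIGNATURE))

      # Compile the corpus with the perturbed policy (uses the tflite_dir
      # parameter added to compile_corpus in this patch), then use the negated
      # total object size as an assumed reward signal.
      output_dir = os.path.join(tempdir, 'output')
      os.makedirs(output_dir, exist_ok=True)
      trace_data_collector.compile_corpus(
          self._corpus_dir, output_dir, self._clang_path, tflite_dir=tflitedir)
      total_size = 0
      for root, _, files in os.walk(output_dir):
        for name in files:
          if name.endswith('.bc.o'):
            total_size += os.path.getsize(os.path.join(root, name))
      return -float(total_size)

Because the learner builds antithetic pairs (p, -p) and dispatches params=self._model_weights + v for each job, a reward computed this way would be evaluated once per signed perturbation, matching the ANTITHETIC estimator configured in blackbox_learner.gin.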