From cea85761e597210a5301e4cb2a44161f11892eda Mon Sep 17 00:00:00 2001 From: Brian Gunnarson Date: Fri, 17 Mar 2023 11:22:34 -0700 Subject: [PATCH] fix bugs for stop-workers and schema validation fix bugs for stop-workers and schema validation fix all flags for stop-workers fix small bug with schema validation modify CHANGELOG run fix-style add myself to contributors list decouple celery logic from main remove unused import make changes Luc requested in PR change worker assignment to use sets --- CHANGELOG.md | 13 ++ CONTRIBUTORS | 1 + merlin/main.py | 20 +- merlin/router.py | 30 +-- merlin/spec/merlinspec.json | 5 +- merlin/spec/specification.py | 145 ++++++++++---- merlin/study/celeryadapter.py | 278 +++++++++++++++++--------- tests/integration/test_definitions.py | 3 +- 8 files changed, 342 insertions(+), 153 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aa643c8d..5ac7348d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Pip wheel wasn't including .sh files for merlin examples - The learn.py script in the openfoam_wf* examples will now create the missing Energy v Lidspeed plot +- Fixed the flags associated with the `stop-workers` command (--spec, --queues, --workers) +- Fixed the --step flag for the `run-workers` command ### Added - Now loads np.arrays of dtype='object', allowing mix-type sample npy @@ -22,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added the --distributed and --display-table flags to run_tests.py - --distributed: only run distributed tests - --display-tests: displays a table of all existing tests and the id associated with each test +- Added the --disable-logs flag to the `run-workers` command +- Merlin will now assign `default_worker` to any step not associated with a worker +- Added `get_step_worker_map()` as a method in `specification.py` ### Changed - Changed celery_regex to celery_slurm_regex in test_definitions.py @@ -29,6 +34,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Test values are now dictionaries rather than tuples - Stopped using `subprocess.Popen()` and `subprocess.communicate()` to run tests and now instead use `subprocess.run()` for simplicity and to keep things up-to-date with the latest subprocess release (`run()` will call `Popen()` and `communicate()` under the hood so we don't have to handle that anymore) - Rewrote the README in the integration tests folder to explain the new integration test format +- Reformatted `start_celery_workers()` in `celeryadapter.py` file. 
This involved: + - Modifying `verify_args()` to return the arguments it verifies/updates + - Changing `launch_celery_worker()` to launch the subprocess (no longer builds the celery command) + - Creating `get_celery_cmd()` to do what `launch_celery_worker()` used to do and build the celery command to run + - Creating `_get_steps_to_start()`, `_create_kwargs()`, and `_get_workers_to_start()` as helper functions to simplify logic in `start_celery_workers()` +- Modified the `merlinspec.json` file: + - the minimum `gpus per task` is now 0 instead of 1 + - variables defined in the `env` block of a spec file can now be arrays ## [1.9.1] ### Fixed diff --git a/CONTRIBUTORS b/CONTRIBUTORS index a920e45b7..2d4a7c9f5 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -6,3 +6,4 @@ Joe Koning Jeremy White Aidan Keogh Ryan Lee +Brian Gunnarson \ No newline at end of file diff --git a/merlin/main.py b/merlin/main.py index 47f471dd8..723fadd22 100644 --- a/merlin/main.py +++ b/merlin/main.py @@ -68,7 +68,7 @@ class HelpParser(ArgumentParser): print the help message when an error happens.""" def error(self, message): - sys.stderr.write("error: %s\n" % message) + sys.stderr.write(f"error: {message}\n") self.print_help() sys.exit(2) @@ -222,7 +222,7 @@ def launch_workers(args): spec, filepath = get_merlin_spec_with_override(args) if not args.worker_echo_only: LOG.info(f"Launching workers from '{filepath}'") - status = router.launch_workers(spec, args.worker_steps, args.worker_args, args.worker_echo_only) + status = router.launch_workers(spec, args.worker_steps, args.worker_args, args.disable_logs, args.worker_echo_only) if args.worker_echo_only: print(status) else: @@ -280,6 +280,8 @@ def stop_workers(args): """ print(banner_small) worker_names = [] + + # Load in the spec if one was provided via the CLI if args.spec: spec_path = verify_filepath(args.spec) spec = MerlinSpec.load_specification(spec_path) @@ -287,6 +289,8 @@ def stop_workers(args): for worker_name in worker_names: if "$" in worker_name: LOG.warning(f"Worker '{worker_name}' is unexpanded. Target provenance spec instead?") + + # Send stop command to router router.stop_workers(args.task_server, worker_names, args.queues, args.workers) @@ -344,6 +348,10 @@ def process_monitor(args): def process_server(args: Namespace): + """ + Route to the correct function based on the command + given via the CLI + """ if args.commands == "init": init_server() elif args.commands == "start": @@ -755,6 +763,12 @@ def generate_worker_touching_parsers(subparsers: ArgumentParser) -> None: help="Specify desired Merlin variable values to override those found in the specification. Space-delimited. " "Example: '--vars LEARN=path/to/new_learn.py EPOCHS=3'", ) + run_workers.add_argument( + "--disable-logs", + action="store_true", + help="Turn off the logs for the celery workers. 
Note: having the -l flag " + "in your workers' args section will overwrite this flag for that worker.", + ) # merlin query-workers query: ArgumentParser = subparsers.add_parser("query-workers", help="List connected task server workers.") @@ -787,6 +801,8 @@ def generate_worker_touching_parsers(subparsers: ArgumentParser) -> None: stop.add_argument( "--workers", type=str, + action="store", + nargs="+", default=None, help="regex match for specific workers to stop", ) diff --git a/merlin/router.py b/merlin/router.py index 3ec9122ff..ab4b8e933 100644 --- a/merlin/router.py +++ b/merlin/router.py @@ -53,7 +53,7 @@ try: - import importlib.resources as resources + from importlib import resources except ImportError: import importlib_resources as resources @@ -74,7 +74,7 @@ def run_task_server(study, run_mode=None): LOG.error("Celery is not specified as the task server!") -def launch_workers(spec, steps, worker_args="", just_return_command=False): +def launch_workers(spec, steps, worker_args="", disable_logs=False, just_return_command=False): """ Launches workers for the specified study. @@ -83,12 +83,13 @@ def launch_workers(spec, steps, worker_args="", just_return_command=False): :param `worker_args`: Optional arguments for the workers :param `just_return_command`: Don't execute, just return the command """ - if spec.merlin["resources"]["task_server"] == "celery": + if spec.merlin["resources"]["task_server"] == "celery": # pylint: disable=R1705 # Start workers - cproc = start_celery_workers(spec, steps, worker_args, just_return_command) + cproc = start_celery_workers(spec, steps, worker_args, disable_logs, just_return_command) return cproc else: LOG.error("Celery is not specified as the task server!") + return "No workers started" def purge_tasks(task_server, spec, force, steps): @@ -103,12 +104,13 @@ def purge_tasks(task_server, spec, force, steps): """ LOG.info(f"Purging queues for steps = {steps}") - if task_server == "celery": + if task_server == "celery": # pylint: disable=R1705 queues = spec.make_queue_string(steps) # Purge tasks return purge_celery_tasks(queues, force) else: LOG.error("Celery is not specified as the task server!") + return -1 def query_status(task_server, spec, steps, verbose=True): @@ -122,12 +124,13 @@ def query_status(task_server, spec, steps, verbose=True): if verbose: LOG.info(f"Querying queues for steps = {steps}") - if task_server == "celery": + if task_server == "celery": # pylint: disable=R1705 queues = spec.get_queue_list(steps) # Query the queues return query_celery_queues(queues) else: LOG.error("Celery is not specified as the task server!") + return [] def dump_status(query_return, csv_file): @@ -141,7 +144,7 @@ def dump_status(query_return, csv_file): fmode = "a" else: fmode = "w" - with open(csv_file, mode=fmode) as f: + with open(csv_file, mode=fmode) as f: # pylint: disable=W1514,C0103 if f.mode == "w": # add the header f.write("# time") for name, job, consumer in query_return: @@ -162,7 +165,7 @@ def query_workers(task_server): LOG.info("Searching for workers...") if task_server == "celery": - return query_celery_workers() + query_celery_workers() else: LOG.error("Celery is not specified as the task server!") @@ -174,10 +177,11 @@ def get_workers(task_server): :return: A list of all connected workers :rtype: list """ - if task_server == "celery": + if task_server == "celery": # pylint: disable=R1705 return get_workers_from_app() else: LOG.error("Celery is not specified as the task server!") + return [] def stop_workers(task_server, spec_worker_names, 
queues, workers_regex): @@ -191,14 +195,14 @@ def stop_workers(task_server, spec_worker_names, queues, workers_regex): """ LOG.info("Stopping workers...") - if task_server == "celery": + if task_server == "celery": # pylint: disable=R1705 # Stop workers - return stop_celery_workers(queues, spec_worker_names, workers_regex) + stop_celery_workers(queues, spec_worker_names, workers_regex) else: LOG.error("Celery is not specified as the task server!") -def route_for_task(name, args, kwargs, options, task=None, **kw): +def route_for_task(name, args, kwargs, options, task=None, **kw): # pylint: disable=W0613,R1710 """ Custom task router for queues """ @@ -249,7 +253,7 @@ def check_merlin_status(args, spec): total_jobs = 0 total_consumers = 0 - for name, jobs, consumers in queue_status: + for _, jobs, consumers in queue_status: total_jobs += jobs total_consumers += consumers diff --git a/merlin/spec/merlinspec.json b/merlin/spec/merlinspec.json index 47e738ee6..3044cd506 100644 --- a/merlin/spec/merlinspec.json +++ b/merlin/spec/merlinspec.json @@ -49,7 +49,7 @@ "type": {"type": "string", "minLength": 1} } }, - "gpus per task": {"type": "integer", "minimum": 1}, + "gpus per task": {"type": "integer", "minimum": 0}, "max_retries": {"type": "integer", "minimum": 1}, "task_queue": {"type": "string", "minLength": 1}, "nodes": { @@ -146,7 +146,8 @@ "^.*": { "anyOf": [ {"type": "string", "minLength": 1}, - {"type": "number"} + {"type": "number"}, + {"type": "array"} ] } } diff --git a/merlin/spec/specification.py b/merlin/spec/specification.py index 287fab1f6..65326d54f 100644 --- a/merlin/spec/specification.py +++ b/merlin/spec/specification.py @@ -48,7 +48,7 @@ LOG = logging.getLogger(__name__) -class MerlinSpec(YAMLSpecification): +class MerlinSpec(YAMLSpecification): # pylint: disable=R0902 """ This class represents the logic for parsing the Merlin yaml specification. @@ -67,8 +67,8 @@ class MerlinSpec(YAMLSpecification): column_labels: [X0, X1] """ - def __init__(self): - super(MerlinSpec, self).__init__() + def __init__(self): # pylint: disable=W0246 + super(MerlinSpec, self).__init__() # pylint: disable=R1725 @property def yaml_sections(self): @@ -123,32 +123,50 @@ def __str__(self): return result @classmethod - def load_specification(cls, filepath, suppress_warning=True): + def load_specification(cls, filepath, suppress_warning=True): # pylint: disable=W0237 + """ + Load in a spec file and create a MerlinSpec object based on its' contents. 
+ + :param `cls`: The class reference (like self) + :param `filepath`: A path to the spec file we're loading in + :param `suppress_warning`: A bool representing whether to warn the user about unrecognized keys + :returns: A MerlinSpec object + """ LOG.info("Loading specification from path: %s", filepath) try: # Load the YAML spec from the filepath with open(filepath, "r") as data: spec = cls.load_spec_from_string(data, needs_IO=False, needs_verification=True) - except Exception as e: + except Exception as e: # pylint: disable=C0103 LOG.exception(e.args) raise e # Path not set in _populate_spec because loading spec with string # does not have a path so we set it here spec.path = filepath - spec.specroot = os.path.dirname(spec.path) + spec.specroot = os.path.dirname(spec.path) # pylint: disable=W0201 if not suppress_warning: spec.warn_unrecognized_keys() return spec @classmethod - def load_spec_from_string(cls, string, needs_IO=True, needs_verification=False): + def load_spec_from_string(cls, string, needs_IO=True, needs_verification=False): # pylint: disable=C0103 + """ + Read in a spec file from a string (or stream) and create a MerlinSpec object from it. + + :param `cls`: The class reference (like self) + :param `string`: A string or stream of the file we're reading in + :param `needs_IO`: A bool representing whether we need to turn the string into a file + object or not + :param `needs_verification`: A bool representing whether we need to verify the spec + :returns: A MerlinSpec object + """ LOG.debug("Creating Merlin spec object...") # Create and populate the MerlinSpec object data = StringIO(string) if needs_IO else string spec = cls._populate_spec(data) - spec.specroot = None + spec.specroot = None # pylint: disable=W0201 spec.process_spec_defaults() LOG.debug("Merlin spec object created.") @@ -179,7 +197,7 @@ def _populate_spec(cls, data): try: spec = yaml.load(data, yaml.FullLoader) except AttributeError: - LOG.warn( + LOG.warning( "PyYAML is using an unsafe version with a known " "load vulnerability. Please upgrade your installation " "to a more recent version!" 
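(Aside for reviewers: the two load paths documented in the docstrings above can be exercised as below. This is a minimal sketch, not part of the patch; the spec filename is hypothetical, and only signatures that appear in this diff are used.)

    from merlin.spec.specification import MerlinSpec

    # Load from disk; per this patch, load_specification() calls
    # load_spec_from_string() with needs_verification=True, so the
    # schema check against merlinspec.json runs during the load.
    spec = MerlinSpec.load_specification("my_study.yaml", suppress_warning=False)

    # Equivalent load from an in-memory string; needs_IO=True wraps the
    # string in a StringIO before parsing.
    with open("my_study.yaml") as spec_file:
        spec_from_string = MerlinSpec.load_spec_from_string(
            spec_file.read(), needs_IO=True, needs_verification=True
        )
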
@@ -198,11 +216,11 @@ def _populate_spec(cls, data): # Reset the file pointer and load the merlin block data.seek(0) - merlin_spec.merlin = MerlinSpec.load_merlin_block(data) + merlin_spec.merlin = MerlinSpec.load_merlin_block(data) # pylint: disable=W0201 # Reset the file pointer and load the user block data.seek(0) - merlin_spec.user = MerlinSpec.load_user_block(data) + merlin_spec.user = MerlinSpec.load_user_block(data) # pylint: disable=W0201 return merlin_spec @@ -262,7 +280,7 @@ def _verify_workers(self): ) raise ValueError(error_msg) - except Exception: + except Exception: # pylint: disable=W0706 raise def verify_merlin_block(self, schema): @@ -288,7 +306,7 @@ def verify_batch_block(self, schema): YAMLSpecification.validate_schema("batch", self.batch, schema) # Additional Walltime checks in case the regex from the schema bypasses an error - if "walltime" in self.batch: + if "walltime" in self.batch: # pylint: disable=R1702 if self.batch["type"] == "lsf": LOG.warning("The walltime argument is not available in lsf.") else: @@ -299,17 +317,17 @@ def verify_batch_block(self, schema): # Walltime must have : if it's not of the form SS if ":" not in walltime: raise ValueError(err_msg) - else: - # Walltime must have exactly 2 chars between : - time = walltime.split(":") - for section in time: - if len(section) != 2: - raise ValueError(err_msg) - except Exception: + # Walltime must have exactly 2 chars between : + time = walltime.split(":") + for section in time: + if len(section) != 2: + raise ValueError(err_msg) + except Exception: # pylint: disable=W0706 raise @staticmethod def load_merlin_block(stream): + """Loads in the merlin block of the spec file""" try: merlin_block = yaml.safe_load(stream)["merlin"] except KeyError: @@ -324,6 +342,7 @@ def load_merlin_block(stream): @staticmethod def load_user_block(stream): + """Loads in the user block of the spec file""" try: user_block = yaml.safe_load(stream)["user"] except KeyError: @@ -331,6 +350,7 @@ def load_user_block(stream): return user_block def process_spec_defaults(self): + """Fills in the default values if they aren't there already""" for name, section in self.sections.items(): if section is None: setattr(self, name, {}) @@ -354,8 +374,25 @@ def process_spec_defaults(self): if self.merlin["resources"]["workers"] is None: self.merlin["resources"]["workers"] = {"default_worker": defaults.WORKER} else: - for worker, vals in self.merlin["resources"]["workers"].items(): - MerlinSpec.fill_missing_defaults(vals, defaults.WORKER) + # Gather a list of step names defined in the study + all_workflow_steps = self.get_study_step_names() + # Create a variable to track the steps assigned to workers + worker_steps = [] + + # Loop through each worker and fill in the defaults + for _, worker_settings in self.merlin["resources"]["workers"].items(): + MerlinSpec.fill_missing_defaults(worker_settings, defaults.WORKER) + worker_steps.extend(worker_settings["steps"]) + + # Figure out which steps still need workers + steps_that_need_workers = list(set(all_workflow_steps) - set(worker_steps)) + + # If there are still steps remaining that haven't been assigned a worker yet, + # assign the remaining steps to the default worker. If all the steps still need workers + # (i.e. 
no workers were assigned) then default workers' steps should be "all" so we skip this + if steps_that_need_workers and (steps_that_need_workers != all_workflow_steps): + self.merlin["resources"]["workers"]["default_worker"] = defaults.WORKER + self.merlin["resources"]["workers"]["default_worker"]["steps"] = steps_that_need_workers if self.merlin["samples"] is not None: MerlinSpec.fill_missing_defaults(self.merlin["samples"], defaults.SAMPLES) @@ -370,7 +407,7 @@ def fill_missing_defaults(object_to_update, default_dict): existing ones. """ - def recurse(result, defaults): + def recurse(result, defaults): # pylint: disable=W0621 if not isinstance(defaults, dict): return for key, val in defaults.items(): @@ -387,6 +424,7 @@ def recurse(result, defaults): # ***Unsure if this method is still needed after adding json schema verification*** def warn_unrecognized_keys(self): + """Checks if there are any unrecognized keys in the spec file""" # check description MerlinSpec.check_section("description", self.description, all_keys.DESCRIPTION) @@ -397,7 +435,7 @@ def warn_unrecognized_keys(self): MerlinSpec.check_section("env", self.environment, all_keys.ENV) # check parameters - for param, contents in self.globals.items(): + for _, contents in self.globals.items(): MerlinSpec.check_section("global.parameters", contents, all_keys.PARAMETER) # check steps @@ -416,13 +454,14 @@ def warn_unrecognized_keys(self): # user block is not checked @staticmethod - def check_section(section_name, section, all_keys): + def check_section(section_name, section, all_keys): # pylint: disable=W0621 + """Checks a section of the spec file to see if there are any unrecognized keys""" diff = set(section.keys()).difference(all_keys) # TODO: Maybe add a check here for required keys for extra in diff: - LOG.warn(f"Unrecognized key '{extra}' found in spec section '{section_name}'.") + LOG.warning(f"Unrecognized key '{extra}' found in spec section '{section_name}'.") def dump(self): """ @@ -434,11 +473,11 @@ def dump(self): result = result.replace("\n\n\n", "\n\n") try: yaml.safe_load(result) - except Exception as e: - raise ValueError(f"Error parsing provenance spec:\n{e}") + except Exception as e: # pylint: disable=C0103 + raise ValueError(f"Error parsing provenance spec:\n{e}") # pylint: disable=W0707 return result - def _dict_to_yaml(self, obj, string, key_stack, tab, newline=True): + def _dict_to_yaml(self, obj, string, key_stack, tab): """ The if-else ladder for sorting the yaml string prettification of dump(). """ @@ -449,12 +488,11 @@ def _dict_to_yaml(self, obj, string, key_stack, tab, newline=True): if isinstance(obj, str): return self._process_string(obj, lvl, tab) - elif isinstance(obj, bool): + if isinstance(obj, bool): return str(obj).lower() - elif not (isinstance(obj, list) or isinstance(obj, dict)): + if not isinstance(obj, (list, dict)): return obj - else: - return self._process_dict_or_list(obj, string, key_stack, lvl, tab) + return self._process_dict_or_list(obj, string, key_stack, lvl, tab) def _process_string(self, obj, lvl, tab): """ @@ -465,15 +503,15 @@ def _process_string(self, obj, lvl, tab): obj = "|\n" + tab * (lvl + 1) + ("\n" + tab * (lvl + 1)).join(split) return obj - def _process_dict_or_list(self, obj, string, key_stack, lvl, tab): + def _process_dict_or_list(self, obj, string, key_stack, lvl, tab): # pylint: disable=R0912,R0913 """ Processes lists and dicts for _dict_to_yaml() in the dump() method. 
""" - from copy import deepcopy + from copy import deepcopy # pylint: disable=C0415 list_offset = 2 * " " if isinstance(obj, list): - n = len(obj) + num_entries = len(obj) use_hyphens = key_stack[-1] in ["paths", "sources", "git", "study"] or key_stack[0] in ["user"] if not use_hyphens: string += "[" @@ -485,8 +523,8 @@ def _process_dict_or_list(self, obj, string, key_stack, lvl, tab): if use_hyphens: string += (lvl + 1) * tab + "- " + str(self._dict_to_yaml(elem, "", key_stack, tab)) + "\n" else: - string += str(self._dict_to_yaml(elem, "", key_stack, tab, newline=(i != 0))) - if n > 1 and i != len(obj) - 1: + string += str(self._dict_to_yaml(elem, "", key_stack, tab)) + if num_entries > 1 and i != len(obj) - 1: string += ", " key_stack.pop() if not use_hyphens: @@ -496,9 +534,9 @@ def _process_dict_or_list(self, obj, string, key_stack, lvl, tab): if len(key_stack) > 0 and key_stack[-1] != "elem": string += "\n" i = 0 - for k, v in obj.items(): + for key, val in obj.items(): key_stack = deepcopy(key_stack) - key_stack.append(k) + key_stack.append(key) if len(key_stack) > 1 and key_stack[-2] == "elem" and i == 0: # string += (tab * (lvl - 1)) string += "" @@ -506,14 +544,32 @@ def _process_dict_or_list(self, obj, string, key_stack, lvl, tab): string += list_offset + (tab * lvl) else: string += tab * (lvl + 1) - string += str(k) + ": " + str(self._dict_to_yaml(v, "", key_stack, tab)) + "\n" + string += str(key) + ": " + str(self._dict_to_yaml(val, "", key_stack, tab)) + "\n" key_stack.pop() i += 1 return string + def get_step_worker_map(self): + """ + Creates a dictionary with step names as keys and a list of workers + associated with each step as values. The inverse of get_worker_step_map(). + """ + steps = self.get_study_step_names() + step_worker_map = {step_name: [] for step_name in steps} + for worker_name, worker_val in self.merlin["resources"]["workers"].items(): + # Case 1: worker doesn't have specific steps + if "all" in worker_val["steps"]: + for step_name in step_worker_map: + step_worker_map[step_name].append(worker_name) + # Case 2: worker has specific steps + else: + for step in worker_val["steps"]: + step_worker_map[step].append(worker_name) + return step_worker_map + def get_task_queues(self): """Returns a dictionary of steps and their corresponding task queues.""" - from merlin.config.configfile import CONFIG + from merlin.config.configfile import CONFIG # pylint: disable=C0415 steps = self.get_study_steps() queues = {} @@ -540,8 +596,8 @@ def get_queue_list(self, steps): else: task_queues = [queues[steps]] except KeyError: - nl = "\n" - LOG.error(f"Invalid steps '{steps}'! Try one of these (or 'all'):\n{nl.join(queues.keys())}") + newline = "\n" + LOG.error(f"Invalid steps '{steps}'! Try one of these (or 'all'):\n{newline.join(queues.keys())}") raise return sorted(set(task_queues)) @@ -555,6 +611,7 @@ def make_queue_string(self, steps): return shlex.quote(queues) def get_worker_names(self): + """Builds a list of workers""" result = [] for worker in self.merlin["resources"]["workers"]: result.append(worker) diff --git a/merlin/study/celeryadapter.py b/merlin/study/celeryadapter.py index f6d344a25..81f0762f8 100644 --- a/merlin/study/celeryadapter.py +++ b/merlin/study/celeryadapter.py @@ -51,8 +51,8 @@ def run_celery(study, run_mode=None): configure Celery to run locally (without workers). 
""" # Only import celery stuff if we want celery in charge - from merlin.celery import app - from merlin.common.tasks import queue_merlin_study + from merlin.celery import app # pylint: disable=C0415 + from merlin.common.tasks import queue_merlin_study # pylint: disable=C0415 adapter_config = study.get_adapter_config(override_type="local") @@ -145,7 +145,7 @@ def query_celery_queues(queues): Send results to the log. """ - from merlin.celery import app + from merlin.celery import app # pylint: disable=C0415 connection = app.connection() found_queues = [] @@ -155,7 +155,7 @@ def query_celery_queues(queues): try: name, jobs, consumers = channel.queue_declare(queue=queue, passive=True) found_queues.append((name, jobs, consumers)) - except Exception as e: + except Exception as e: # pylint: disable=C0103,W0718 LOG.warning(f"Cannot find queue {queue} on server.{e}") finally: connection.close() @@ -169,7 +169,7 @@ def get_workers_from_app(): :return: A list of all connected workers :rtype: list """ - from merlin.celery import app + from merlin.celery import app # pylint: disable=C0415 i = app.control.inspect() workers = i.ping() @@ -178,10 +178,90 @@ def get_workers_from_app(): return [*workers] -def start_celery_workers(spec, steps, celery_args, just_return_command): +def _get_workers_to_start(spec, steps): + """ + Helper function to return a set of workers to start based on + the steps provided by the user. + + :param `spec`: A MerlinSpec object + :param `steps`: A list of steps to start workers for + + :returns: A set of workers to start + """ + workers_to_start = [] + step_worker_map = spec.get_step_worker_map() + for step in steps: + try: + workers_to_start.extend(step_worker_map[step]) + except KeyError: + LOG.warning(f"Cannot start workers for step: {step}. This step was not found.") + + workers_to_start = set(workers_to_start) + LOG.debug(f"workers_to_start: {workers_to_start}") + + return workers_to_start + + +def _create_kwargs(spec): + """ + Helper function to handle creating the kwargs dict that + we'll pass to subprocess.Popen when we launch the worker. + + :param `spec`: A MerlinSpec object + :returns: A tuple where the first entry is the kwargs and + the second entry is variables defined in the spec + """ + # Get the environment from the spec and the shell + spec_env = spec.environment + shell_env = os.environ.copy() + yaml_vars = None + + # If the environment from the spec has anything in it, + # read in the variables and save them to the shell environment + if spec_env: + yaml_vars = get_yaml_var(spec_env, "variables", {}) + for var_name, var_val in yaml_vars.items(): + shell_env[str(var_name)] = str(var_val) + # For expandvars + os.environ[str(var_name)] = str(var_val) + + # Create the kwargs dict + kwargs = {"env": shell_env, "shell": True, "universal_newlines": True} + return kwargs, yaml_vars + + +def _get_steps_to_start(wsteps, steps, steps_provided): + """ + Determine which steps to start workers for. 
+ + :param `wsteps`: A list of steps associated with a worker + :param `steps`: A list of steps to start provided by the user + :param `steps`: A bool representing whether the user gave specific + steps to start or not + :returns: A list of steps to start workers for + """ + steps_to_start = [] + if steps_provided: + for wstep in wsteps: + if wstep in steps: + steps_to_start.append(wstep) + else: + steps_to_start.extend(wsteps) + + return steps_to_start + + +def start_celery_workers(spec, steps, celery_args, disable_logs, just_return_command): # pylint: disable=R0914,R0915 """Start the celery workers on the allocation - specs Tuple of (YAMLSpecification, MerlinSpec) + :param MerlinSpec spec: A MerlinSpec object representing our study + :param list steps: A list of steps to start workers for + :param str celery_args: A string of arguments to provide to the celery workers + :param bool disable_logs: A boolean flag to turn off the celery logs for the workers + :param bool just_return_command: When True, workers aren't started and just the launch command(s) + are returned + :side effect: Starts subprocesses for each worker we launch + :returns: A string of all the worker launch commands ... example config: @@ -203,20 +283,22 @@ def start_celery_workers(spec, steps, celery_args, just_return_command): overlap = spec.merlin["resources"]["overlap"] workers = spec.merlin["resources"]["workers"] - senv = spec.environment - spenv = os.environ.copy() - yenv = None - if senv: - yenv = get_yaml_var(senv, "variables", {}) - for k, v in yenv.items(): - spenv[str(k)] = str(v) - # For expandvars - os.environ[str(k)] = str(v) + # Build kwargs dict for subprocess.Popen to use when we launch the worker + kwargs, yenv = _create_kwargs(spec) worker_list = [] local_queues = [] + # Get the workers we need to start if we're only starting certain steps + steps_provided = False if "all" in steps else True # pylint: disable=R1719 + if steps_provided: + workers_to_start = _get_workers_to_start(spec, steps) + for worker_name, worker_val in workers.items(): + # Only triggered if --steps flag provided + if steps_provided and worker_name not in workers_to_start: + continue + skip_loop_step: bool = examine_and_log_machines(worker_val, yenv) if skip_loop_step: continue @@ -227,73 +309,60 @@ def start_celery_workers(spec, steps, celery_args, just_return_command): worker_args = "" worker_nodes = get_yaml_var(worker_val, "nodes", None) - worker_batch = get_yaml_var(worker_val, "batch", None) + # Get the correct steps to start workers for wsteps = get_yaml_var(worker_val, "steps", steps) - queues = spec.make_queue_string(wsteps).split(",") + steps_to_start = _get_steps_to_start(wsteps, steps, steps_provided) + queues = spec.make_queue_string(steps_to_start) # Check for missing arguments - verify_args(spec, worker_args, worker_name, overlap) + worker_args = verify_args(spec, worker_args, worker_name, overlap, disable_logs=disable_logs) # Add a per worker log file (debug) if LOG.isEnabledFor(logging.DEBUG): LOG.debug("Redirecting worker output to individual log files") worker_args += " --logfile %p.%i" - # Get the celery command - celery_com = launch_celery_workers(spec, steps=wsteps, worker_args=worker_args, just_return_command=True) - + # Get the celery command & add it to the batch launch command + celery_com = get_celery_cmd(queues, worker_args=worker_args, just_return_command=True) celery_cmd = os.path.expandvars(celery_com) - worker_cmd = batch_worker_launch(spec, celery_cmd, nodes=worker_nodes, batch=worker_batch) - 
worker_cmd = os.path.expandvars(worker_cmd) - try: - kwargs = {"env": spenv, "shell": True, "universal_newlines": True} - # These cannot be used with a detached process - # "stdout": subprocess.PIPE, - # "stderr": subprocess.PIPE, - - LOG.debug(f"worker cmd={worker_cmd}") - LOG.debug(f"env={spenv}") - - if just_return_command: - worker_list = "" - print(worker_cmd) - continue - - found = [] - running_queues = [] - - running_queues.extend(local_queues) - if not overlap: - running_queues.extend(get_running_queues()) - # Cache the queues from this worker to use to test - # for existing queues in any subsequent workers. - # If overlap is True, then do not check the local queues. - # This will allow multiple workers to pull from the same - # queue. - local_queues.extend(queues) - - for q in queues: - if q in running_queues: - found.append(q) - - if found: - LOG.warning( - f"A celery worker named '{worker_name}' is already configured/running for queue(s) = {' '.join(found)}" - ) - continue + LOG.debug(f"worker cmd={worker_cmd}") - _ = subprocess.Popen(worker_cmd, **kwargs) + if just_return_command: + worker_list = "" + print(worker_cmd) + continue - worker_list.append(worker_cmd) + # Get the running queues + running_queues = [] + running_queues.extend(local_queues) + queues = queues.split(",") + if not overlap: + running_queues.extend(get_running_queues()) + # Cache the queues from this worker to use to test + # for existing queues in any subsequent workers. + # If overlap is True, then do not check the local queues. + # This will allow multiple workers to pull from the same + # queue. + local_queues.extend(queues) + + # Search for already existing queues and log a warning if we try to start one that already exists + found = [] + for q in queues: # pylint: disable=C0103 + if q in running_queues: + found.append(q) + if found: + LOG.warning( + f"A celery worker named '{worker_name}' is already configured/running for queue(s) = {' '.join(found)}" + ) + continue - except Exception as e: - LOG.error(f"Cannot start celery workers, {e}") - raise + # Start the worker + launch_celery_worker(worker_cmd, worker_list, kwargs) # Return a string with the worker commands for logging return str(worker_list) @@ -306,7 +375,7 @@ def examine_and_log_machines(worker_val, yenv) -> bool: """ worker_machines = get_yaml_var(worker_val, "machines", None) if worker_machines: - LOG.debug("check machines = ", check_machines(worker_machines)) + LOG.debug(f"check machines = {check_machines(worker_machines)}") if not check_machines(worker_machines): return True @@ -320,11 +389,10 @@ def examine_and_log_machines(worker_val, yenv) -> bool: "The env:variables section does not have an OUTPUT_PATH specified, multi-machine checks cannot be performed." 
) return False - else: - return False + return False -def verify_args(spec, worker_args, worker_name, overlap): +def verify_args(spec, worker_args, worker_name, overlap, disable_logs=False): """Examines the args passed to a worker for completeness.""" parallel = batch_check_parallel(spec) if parallel: @@ -340,29 +408,46 @@ def verify_args(spec, worker_args, worker_name, overlap): if overlap: nhash = time.strftime("%Y%m%d-%H%M%S") # TODO: Once flux fixes their bug, change this back to %h + # %h in Celery is short for hostname including domain name worker_args += f" -n {worker_name}{nhash}.%%h" - if "-l" not in worker_args: + if not disable_logs and "-l" not in worker_args: worker_args += f" -l {logging.getLevelName(LOG.getEffectiveLevel())}" + return worker_args + + +def launch_celery_worker(worker_cmd, worker_list, kwargs): + """ + Using the worker launch command provided, launch a celery worker. + :param str worker_cmd: The celery command to launch a worker + :param list worker_list: A list of worker launch commands + :param dict kwargs: A dictionary containing additional keyword args to provide + to subprocess.Popen -def launch_celery_workers(spec, steps=None, worker_args="", just_return_command=False): + :side effect: Launches a celery worker via a subprocess """ - Launch celery workers for the specified MerlinStudy. + try: + _ = subprocess.Popen(worker_cmd, **kwargs) # pylint: disable=R1732 + worker_list.append(worker_cmd) + except Exception as e: # pylint: disable=C0103 + LOG.error(f"Cannot start celery workers, {e}") + raise + - spec MerlinSpec object - steps The steps in the spec to tie the workers to +def get_celery_cmd(queue_names, worker_args="", just_return_command=False): + """ + Get the appropriate command to launch celery workers for the specified MerlinStudy. + queue_names The name(s) of the queue(s) to associate a worker with worker_args Optional celery arguments for the workers just_return_command Don't execute, just return the command """ - queues = spec.make_queue_string(steps) - worker_command = " ".join(["celery -A merlin worker", worker_args, "-Q", queues]) + worker_command = " ".join(["celery -A merlin worker", worker_args, "-Q", queue_names]) if just_return_command: return worker_command - else: - # This only runs celery locally the user would need to - # add all of the flux config themselves. - pass + # If we get down here, this only runs celery locally the user would need to + # add all of the flux config themselves. 
+ return "" def purge_celery_tasks(queues, force): @@ -402,7 +487,8 @@ def stop_celery_workers(queues=None, spec_worker_names=None, worker_regex=None): >>> stop_celery_workers() """ - from merlin.celery import app + from merlin.celery import app # pylint: disable=C0415 + from merlin.config.configfile import CONFIG # pylint: disable=C0415 LOG.debug(f"Sending stop to queues: {queues}, worker_regex: {worker_regex}, spec_worker_names: {spec_worker_names}") active_queues, _ = get_queues(app) @@ -410,6 +496,10 @@ def stop_celery_workers(queues=None, spec_worker_names=None, worker_regex=None): # If not specified, get all the queues if queues is None: queues = [*active_queues] + # Celery adds the queue tag in front of each queue so we add that here + else: + for i, queue in enumerate(queues): + queues[i] = f"{CONFIG.celery.queue_tag}{queue}" # Find the set of all workers attached to all of those queues all_workers = set() @@ -424,23 +514,31 @@ def stop_celery_workers(queues=None, spec_worker_names=None, worker_regex=None): LOG.debug(f"Pre-filter worker stop list: {all_workers}") - print(f"all_workers: {all_workers}") - print(f"spec_worker_names: {spec_worker_names}") + # Stop workers with no flags if (spec_worker_names is None or len(spec_worker_names) == 0) and worker_regex is None: workers_to_stop = list(all_workers) + # Flag handling else: workers_to_stop = [] + # --spec flag if (spec_worker_names is not None) and len(spec_worker_names) > 0: for worker_name in spec_worker_names: - print(f"Result of regex_list_filter: {regex_list_filter(worker_name, all_workers)}") + LOG.debug( + f"""Result of regex_list_filter for {worker_name}: + {regex_list_filter(worker_name, all_workers, match=False)}""" + ) workers_to_stop += regex_list_filter(worker_name, all_workers, match=False) + # --workers flag if worker_regex is not None: - workers_to_stop += regex_list_filter(worker_regex, workers_to_stop) + for worker in worker_regex: + LOG.debug(f"Result of regex_list_filter: {regex_list_filter(worker, all_workers, match=False)}") + workers_to_stop += regex_list_filter(worker, all_workers, match=False) - print(f"workers_to_stop: {workers_to_stop}") + # Remove duplicates + workers_to_stop = list(set(workers_to_stop)) if workers_to_stop: LOG.info(f"Sending stop to these workers: {workers_to_stop}") - return app.control.broadcast("shutdown", destination=workers_to_stop) + app.control.broadcast("shutdown", destination=workers_to_stop) else: LOG.warning("No workers found to stop") @@ -454,10 +552,10 @@ def create_celery_config(config_dir, data_file_name, data_file_path): :param `data_file_path`: The full data file path. 
""" # This will need to come from the server interface - MERLIN_CONFIG = os.path.join(config_dir, data_file_name) + MERLIN_CONFIG = os.path.join(config_dir, data_file_name) # pylint: disable=C0103 if os.path.isfile(MERLIN_CONFIG): - from merlin.common.security import encrypt + from merlin.common.security import encrypt # pylint: disable=C0415 encrypt.init_key() LOG.info(f"The config file already exists, {MERLIN_CONFIG}") @@ -471,6 +569,6 @@ def create_celery_config(config_dir, data_file_name, data_file_path): LOG.info(f"The file {MERLIN_CONFIG} is ready to be edited for your system.") - from merlin.common.security import encrypt + from merlin.common.security import encrypt # pylint: disable=C0415 encrypt.init_key() diff --git a/tests/integration/test_definitions.py b/tests/integration/test_definitions.py index 752b3650d..aafbabe2f 100644 --- a/tests/integration/test_definitions.py +++ b/tests/integration/test_definitions.py @@ -56,8 +56,7 @@ OUTPUT_DIR = "cli_test_studies" CLEAN_MERLIN_SERVER = "rm -rf appendonly.aof dump.rdb merlin_server/" -# KILL_WORKERS = "pkill -9 -f '.*merlin_test_worker'" -KILL_WORKERS = "pkill -9 -f 'celery'" +KILL_WORKERS = "pkill -9 -f '.*merlin_test_worker'" def define_tests(): # pylint: disable=R0914