From 2b8ba1d8ef1595e5ffabe981cc8b37faad0e3dd9 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Wed, 9 Oct 2024 11:24:28 +0200 Subject: [PATCH 01/12] remove (object) from classes --- src/amuse/rfi/core.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/amuse/rfi/core.py b/src/amuse/rfi/core.py index 1a64358aee..773eeb478b 100644 --- a/src/amuse/rfi/core.py +++ b/src/amuse/rfi/core.py @@ -7,7 +7,6 @@ import pydoc import traceback import random -import sys import warnings import inspect @@ -34,7 +33,7 @@ from amuse import config except ImportError as ex: - class config(object): + class config: is_mpi_enabled = False @@ -84,7 +83,7 @@ def _typecode_to_datatype(typecode): raise exceptions.AmuseException("{0} is not a valid typecode".format(typecode)) -class CodeFunction(object): +class CodeFunction: __doc__ = CodeDocStringProperty() @@ -318,7 +317,7 @@ def __str__(self): return str(self.specification) -class legacy_function(object): +class legacy_function: __doc__ = CodeDocStringProperty() @@ -333,7 +332,7 @@ def __init__(self, specification_function): a LegacyFunctionSpecification. - >>> class LegacyExample(object): + >>> class LegacyExample: ... @legacy_function ... def evolve(): ... specification = LegacyFunctionSpecification() @@ -573,7 +572,7 @@ def remote_function(f=None, must_handle_array=False, can_handle_array=False): ) -class ParameterSpecification(object): +class ParameterSpecification: def __init__(self, name, dtype, direction, description, default=None, unit=None): """Specification of a parameter of a legacy function""" self.name = name @@ -601,7 +600,7 @@ def has_default_value(self): return not self.default is None -class LegacyFunctionSpecification(object): +class LegacyFunctionSpecification: """ Specification of a legacy function. 
Describes the name, result type and parameters of a @@ -1185,7 +1184,7 @@ def get_working_directory(): return function -class CodeWithDataDirectories(object): +class CodeWithDataDirectories: def __init__(self): if self.channel_type == "distributed": From 0d4172511a71b5d0daf16dabf1a052ba221bdc68 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Wed, 9 Oct 2024 11:29:28 +0200 Subject: [PATCH 02/12] improve pep8 compliance --- src/amuse/rfi/core.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/amuse/rfi/core.py b/src/amuse/rfi/core.py index 773eeb478b..7899e247f7 100644 --- a/src/amuse/rfi/core.py +++ b/src/amuse/rfi/core.py @@ -1,3 +1,5 @@ +from amuse.rfi.channel import LocalChannel +import numpy import weakref import atexit import errno @@ -47,10 +49,6 @@ class config: class for all community codes. """ -import numpy - -from amuse.rfi.channel import LocalChannel - def ensure_mpd_is_running(): from mpi4py import MPI @@ -529,7 +527,10 @@ def wrapper(f): def returns(**kwargs): start = flatsrc.find("returns(") - order = lambda k: flatsrc.find(k[0] + "=", start) + + def order(k): + return flatsrc.find(k[0] + "=", start) + out_arg.extend(sorted(kwargs.items(), key=order)) f.__globals__["returns"] = returns @@ -1365,7 +1366,7 @@ def __call__(self, *arguments_list, **keyword_arguments): handle_as_array = self.must_handle_as_array(dtype_to_values) - if not self.owner is None: + if self.owner is not None: CODE_LOG.info( "start call '%s.%s'", self.owner.__name__, self.specification.name ) @@ -1397,7 +1398,7 @@ def __call__(self, *arguments_list, **keyword_arguments): output_units = self.convert_floats_to_units(output_encoded_units) result = self.converted_results(dtype_to_result, handle_as_array, output_units) - if not self.owner is None: + if self.owner is not None: CODE_LOG.info( "end call '%s.%s'", self.owner.__name__, self.specification.name ) @@ -1485,7 +1486,7 @@ def converted_results(self, dtype_to_result, 
must_handle_as_array, units): for key, value in dtype_to_result.items(): dtype_to_array[key] = list(reversed(value)) - if not result_type is None: + if result_type is not None: return_value = dtype_to_array[result_type].pop() for parameter in self.specification.output_parameters: @@ -1498,7 +1499,7 @@ def converted_results(self, dtype_to_result, must_handle_as_array, units): result[parameter.name] | units[parameter.index_in_output] ) - if not result_type is None: + if result_type is not None: result["__result"] = return_value return result From a5094d99bc7287e9f4896c930d8629527f64f438 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Wed, 9 Oct 2024 11:32:23 +0200 Subject: [PATCH 03/12] if not ... is None -> if ... is not None (etc) --- src/amuse/rfi/core.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/amuse/rfi/core.py b/src/amuse/rfi/core.py index 7899e247f7..7ae521df7e 100644 --- a/src/amuse/rfi/core.py +++ b/src/amuse/rfi/core.py @@ -111,7 +111,7 @@ def __call__(self, *arguments_list, **keyword_arguments): handle_as_array = self.must_handle_as_array(dtype_to_values) - if not self.owner is None: + if self.owner is not None: CODE_LOG.info( "start call '%s.%s'", self.owner.__name__, self.specification.name ) @@ -140,7 +140,7 @@ def __call__(self, *arguments_list, **keyword_arguments): result = self.converted_results(dtype_to_result, handle_as_array) - if not self.owner is None: + if self.owner is not None: CODE_LOG.info( "end call '%s.%s'", self.owner.__name__, self.specification.name ) @@ -227,7 +227,7 @@ def result_index(self): index = [] for parameter in self.specification.output_parameters: index.append(parameter.name) - if not self.specification.result_type is None: + if self.specification.result_type is not None: index.append("__result") return index @@ -259,13 +259,13 @@ def converted_results(self, dtype_to_result, must_handle_as_array): for key, value in dtype_to_result.items(): dtype_to_array[key] = 
list(reversed(value)) - if not result_type is None: + if result_type is not None: return_value = dtype_to_array[result_type].pop() for parameter in self.specification.output_parameters: result[parameter.name] = dtype_to_array[parameter.datatype].pop() - if not result_type is None: + if result_type is not None: result["__result"] = return_value return result @@ -770,7 +770,7 @@ def __str__(self): p + typecode_to_name[x.datatype] p + " " p + x.name - if not self.result_type is None: + if self.result_type is not None: p + ", " p + typecode_to_name[self.result_type] p + " " @@ -800,7 +800,7 @@ def stop_interfaces(exceptions=[]): """ for reference in reversed(CodeInterface.instances): x = reference() - if not x is None and x.__class__.__name__ not in exceptions: + if x is not None and x.__class__.__name__ not in exceptions: try: x._stop() except: @@ -928,7 +928,7 @@ def ensure_stop_interface_at_exit(cls): @classmethod def retrieve_reusable_channel(cls): - if not "REUSE_INSTANCE" in cls.__dict__: + if "REUSE_INSTANCE" not in cls.__dict__: cls.REUSE_INSTANCE = set([]) s = cls.REUSE_INSTANCE if len(s) > 0: @@ -938,7 +938,7 @@ def retrieve_reusable_channel(cls): @classmethod def store_reusable_channel(cls, instance): - if not "REUSE_INSTANCE" in cls.__dict__: + if "REUSE_INSTANCE" not in cls.__dict__: cls.REUSE_INSTANCE = set([]) s = cls.REUSE_INSTANCE s.add(instance) @@ -946,7 +946,7 @@ def store_reusable_channel(cls, instance): @classmethod def stop_reusable_channels(cls): - if not "REUSE_INSTANCE" in cls.__dict__: + if "REUSE_INSTANCE" not in cls.__dict__: cls.REUSE_INSTANCE = set([]) s = cls.REUSE_INSTANCE while len(s) > 0: @@ -960,7 +960,7 @@ def stop_reusable_channels(cls): def _stop(self): if hasattr(self, "channel"): - if not self.channel is None and self.channel.is_active(): + if self.channel is not None and self.channel.is_active(): if self.reuse_worker: self.store_reusable_channel(self.channel) self.channel = None @@ -1493,7 +1493,7 @@ def converted_results(self, 
dtype_to_result, must_handle_as_array, units): result[parameter.name] = dtype_to_array[parameter.datatype].pop() if ( self.specification.has_units - and not units[parameter.index_in_output] is None + and units[parameter.index_in_output] is not None ): result[parameter.name] = ( result[parameter.name] | units[parameter.index_in_output] From 57ad8af26236c4b581244704654bfd33bfd11e98 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Wed, 9 Oct 2024 11:39:15 +0200 Subject: [PATCH 04/12] reindent some text --- src/amuse/rfi/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/amuse/rfi/core.py b/src/amuse/rfi/core.py index 7ae521df7e..58c0cc16ea 100644 --- a/src/amuse/rfi/core.py +++ b/src/amuse/rfi/core.py @@ -320,14 +320,14 @@ class legacy_function: __doc__ = CodeDocStringProperty() def __init__(self, specification_function): - """Decorator for legacy functions. + """ + Decorator for legacy functions. The decorated function cannot have any arguments. This means the decorated function must not have a ``self`` argument. - The decorated function must return - a LegacyFunctionSpecification. + The decorated function must return a LegacyFunctionSpecification. >>> class LegacyExample: From dd249af6ac1e8140d810d79169f30ec1e088b978 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Wed, 9 Oct 2024 11:39:52 +0200 Subject: [PATCH 05/12] simplify crc32 and reduce duplicate lines --- src/amuse/rfi/core.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/amuse/rfi/core.py b/src/amuse/rfi/core.py index 58c0cc16ea..15f170f7ac 100644 --- a/src/amuse/rfi/core.py +++ b/src/amuse/rfi/core.py @@ -386,32 +386,21 @@ def is_compiled_file_up_to_date(self, time_of_the_compiled_file): def crc32(self): try: from zlib import crc32 + except ImportError: + try: + from binascii import crc32 + except ImportError: + raise Exception("No working crc32 implementation found!") - # python 3, crc32 needs bytes... 
- - def python3_crc32(x): - x = crc32(bytes(x, "ascii")) - return x - ((x & 0x80000000) << 1) - - if python3_crc32("amuse") & 0xFFFFFFFF == 0xC0CC9367: - return python3_crc32 - except Exception: - pass - try: - from binascii import crc32 - - # python 3, crc32 needs bytes... + # python 3, crc32 needs bytes... - def python3_crc32(x): - x = crc32(bytes(x, "ascii")) - return x - ((x & 0x80000000) << 1) + def python3_crc32(x): + x = crc32(bytes(x, "ascii")) + return x - ((x & 0x80000000) << 1) - if python3_crc32("amuse") & 0xFFFFFFFF == 0xC0CC9367: - return python3_crc32 - except Exception: - pass + if python3_crc32("amuse") & 0xFFFFFFFF == 0xC0CC9367: + return python3_crc32 - raise Exception("No working crc32 implementation found!") def derive_dtype_unit_and_default(value): From 647a5f45a248fe0fffc28ee9b8d607b7b5152491 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Wed, 9 Oct 2024 15:28:09 +0200 Subject: [PATCH 06/12] replace '#~' comments with '#' comments --- src/amuse/rfi/async_request.py | 20 +- src/amuse/rfi/channel.py | 2293 +++++++++++++++++++------------- src/amuse/rfi/gencode.py | 14 +- src/amuse/rfi/python_code.py | 10 +- 4 files changed, 1383 insertions(+), 954 deletions(-) diff --git a/src/amuse/rfi/async_request.py b/src/amuse/rfi/async_request.py index 707a7447fe..cde5a07ce1 100644 --- a/src/amuse/rfi/async_request.py +++ b/src/amuse/rfi/async_request.py @@ -58,8 +58,8 @@ def get_mpi_request(self): def get_socket(self): raise Exception("not implemented") - #~ def is_pool(self): - #~ return False + # def is_pool(self): + # return False def join(self, other): if other is None: @@ -82,8 +82,8 @@ def waits_for(self): def __getitem__(self, index): return IndexedASyncRequest(self,index) - #~ def __getattr__(self, name): - #~ print name, "<<" + # def __getattr__(self, name): + # print name, "<<" def __add__(self, other): return baseOperatorASyncRequest(self,other, operator.add) @@ -127,8 +127,8 @@ def __iter__(self): else: yield self - #~ def __call__(self): 
- #~ return self.result() + # def __call__(self): + # return self.result() class DependentASyncRequest(AbstractASyncRequest): def __init__(self, parent, request_factory): @@ -183,12 +183,12 @@ def is_result_available(self): if self.request is None: return False - #~ if not self.parent.is_finished: - #~ return False + # if not self.parent.is_finished: + # return False if self.request is None: return False - #~ raise Exception("something went wrong (exception of parent?)") + # raise Exception("something went wrong (exception of parent?)") return self.request.is_result_available() @@ -607,7 +607,7 @@ def add_request(self, async_request, result_handler = None, args=(), kwargs={}): return if async_request in self.registered_requests: return - #~ raise Exception("Request is already registered, cannot register a request more than once") + # raise Exception("Request is already registered, cannot register a request more than once") self.registered_requests.add(async_request) diff --git a/src/amuse/rfi/channel.py b/src/amuse/rfi/channel.py index d47c092e98..0486837d7b 100644 --- a/src/amuse/rfi/channel.py +++ b/src/amuse/rfi/channel.py @@ -24,14 +24,14 @@ # so actual import is in function ensure_mpi_initialized # MPI = None - + from subprocess import Popen, PIPE try: from amuse import config except ImportError: config = None - + from amuse.support.options import OptionalAttributes, option, GlobalOptions from amuse.support.core import late from amuse.support import exceptions @@ -43,26 +43,31 @@ from . 
import async_request + class AbstractMessage(object): - - def __init__(self, - call_id=0, function_id=-1, call_count=1, + + def __init__( + self, + call_id=0, + function_id=-1, + call_count=1, dtype_to_arguments={}, error=False, - big_endian=(sys.byteorder.lower() == 'big'), + big_endian=(sys.byteorder.lower() == "big"), polling_interval=0, - encoded_units = ()): + encoded_units=(), + ): self.polling_interval = polling_interval - + # flags self.big_endian = big_endian self.error = error - + # header self.call_id = call_id self.function_id = function_id self.call_count = call_count - + # data (numpy arrays) self.ints = [] self.longs = [] @@ -72,64 +77,63 @@ def __init__(self, self.booleans = [] self.pack_data(dtype_to_arguments) - + self.encoded_units = encoded_units - def pack_data(self, dtype_to_arguments): for dtype, attrname in self.dtype_to_message_attribute(): if dtype in dtype_to_arguments: array = pack_array(dtype_to_arguments[dtype], self.call_count, dtype) setattr(self, attrname, array) - + def to_result(self, handle_as_array=False): dtype_to_result = {} for dtype, attrname in self.dtype_to_message_attribute(): result = getattr(self, attrname) if self.call_count > 1 or handle_as_array: - dtype_to_result[dtype] = unpack_array(result , self.call_count, dtype) + dtype_to_result[dtype] = unpack_array(result, self.call_count, dtype) else: dtype_to_result[dtype] = result - + return dtype_to_result - + def dtype_to_message_attribute(self): return ( - ('int32', 'ints'), - ('int64', 'longs'), - ('float32', 'floats'), - ('float64', 'doubles'), - ('bool', 'booleans'), - ('string', 'strings'), + ("int32", "ints"), + ("int64", "longs"), + ("float32", "floats"), + ("float64", "doubles"), + ("bool", "booleans"), + ("string", "strings"), ) - + def receive(self, comm): raise NotImplementedError - + def send(self, comm): raise NotImplementedError - + def set_error(self, message): self.strings = [message] self.error = True - - + + class MPIMessage(AbstractMessage): - + def 
receive(self, comm): header = self.receive_header(comm) self.receive_content(comm, header) - + def receive_header(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") self.mpi_receive(comm, [header, MPI.INT]) return header - + def receive_content(self, comm, header): # 4 flags as 8bit booleans in 1st 4 bytes of header - # endiannes(not supported by MPI channel), error, unused, unused + # endiannes(not supported by MPI channel), error, unused, unused - flags = header.view(dtype='bool_') + flags = header.view(dtype="bool_") self.big_endian = flags[0] self.error = flags[1] self.is_continued = flags[2] @@ -145,117 +149,119 @@ def receive_content(self, comm, header): number_of_booleans = header[8] number_of_strings = header[9] number_of_units = header[10] - + self.ints = self.receive_ints(comm, number_of_ints) self.longs = self.receive_longs(comm, number_of_longs) self.floats = self.receive_floats(comm, number_of_floats) self.doubles = self.receive_doubles(comm, number_of_doubles) self.booleans = self.receive_booleans(comm, number_of_booleans) self.strings = self.receive_strings(comm, number_of_strings) - + self.encoded_units = self.receive_doubles(comm, number_of_units) - def nonblocking_receive(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") request = self.mpi_nonblocking_receive(comm, [header, MPI.INT]) return async_request.ASyncRequest(request, self, comm, header) - + def receive_doubles(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='d') + result = numpy.empty(total, dtype="d") self.mpi_receive(comm, [result, MPI.DOUBLE]) return result else: return [] - + def receive_ints(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='i') + result = numpy.empty(total, dtype="i") self.mpi_receive(comm, [result, MPI.INT]) return result else: return [] - + def receive_longs(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='int64') + result = 
numpy.empty(total, dtype="int64") self.mpi_receive(comm, [result, MPI.INTEGER8]) return result else: return [] - + def receive_floats(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='f') + result = numpy.empty(total, dtype="f") self.mpi_receive(comm, [result, MPI.FLOAT]) return result else: return [] - - + def receive_booleans(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='b') - self.mpi_receive(comm, [result, MPI.C_BOOL or MPI.BYTE]) # if C_BOOL null datatype (ie undefined) fallback + result = numpy.empty(total, dtype="b") + self.mpi_receive( + comm, [result, MPI.C_BOOL or MPI.BYTE] + ) # if C_BOOL null datatype (ie undefined) fallback return numpy.logical_not(result == 0) else: return [] - - + def receive_strings(self, comm, total): if total > 0: - sizes = numpy.empty(total, dtype='i') - + sizes = numpy.empty(total, dtype="i") + self.mpi_receive(comm, [sizes, MPI.INT]) - + logger.debug("got %d strings of size %s", total, sizes) - + byte_size = 0 for size in sizes: byte_size = byte_size + size + 1 - + data_bytes = numpy.empty(byte_size, dtype=numpy.uint8) self.mpi_receive(comm, [data_bytes, MPI.CHARACTER]) - + strings = [] begin = 0 for size in sizes: - strings.append(data_bytes[begin:begin + size].tobytes().decode('latin_1')) + strings.append( + data_bytes[begin : begin + size].tobytes().decode("latin_1") + ) begin = begin + size + 1 - + logger.debug("got %d strings of size %s, data = %s", total, sizes, strings) return numpy.array(strings) else: return [] - - + def send(self, comm): - header = numpy.array([ - 0, - self.call_id, - self.function_id, - self.call_count, - len(self.ints) , - len(self.longs) , - len(self.floats) , - len(self.doubles) , - len(self.booleans) , - len(self.strings) , - len(self.encoded_units) - ], dtype='i') - - - flags = header.view(dtype='bool_') + header = numpy.array( + [ + 0, + self.call_id, + self.function_id, + self.call_count, + len(self.ints), + len(self.longs), + len(self.floats), + 
len(self.doubles), + len(self.booleans), + len(self.strings), + len(self.encoded_units), + ], + dtype="i", + ) + + flags = header.view(dtype="bool_") flags[0] = self.big_endian flags[1] = self.error flags[2] = len(self.encoded_units) > 0 self.send_header(comm, header) self.send_content(comm) - + def send_header(self, comm, header): self.mpi_send(comm, [header, MPI.INT]) - + def send_content(self, comm): self.send_ints(comm, self.ints) self.send_longs(comm, self.longs) @@ -264,102 +270,103 @@ def send_content(self, comm): self.send_booleans(comm, self.booleans) self.send_strings(comm, self.strings) self.send_doubles(comm, self.encoded_units) - def send_ints(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='int32') + sendbuffer = numpy.array(array, dtype="int32") self.mpi_send(comm, [sendbuffer, MPI.INT]) - + def send_longs(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='int64') - self.mpi_send(comm, [sendbuffer, MPI.INTEGER8]) - + sendbuffer = numpy.array(array, dtype="int64") + self.mpi_send(comm, [sendbuffer, MPI.INTEGER8]) + def send_doubles(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='d') + sendbuffer = numpy.array(array, dtype="d") self.mpi_send(comm, [sendbuffer, MPI.DOUBLE]) - + def send_floats(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='f') + sendbuffer = numpy.array(array, dtype="f") self.mpi_send(comm, [sendbuffer, MPI.FLOAT]) - + def send_strings(self, comm, array): if len(array) == 0: return - - lengths = numpy.array( [len(s) for s in array] ,dtype='i') - - chars=(chr(0).join(array)+chr(0)).encode("utf-8") - chars = numpy.frombuffer(chars, dtype='uint8') - if len(chars) != lengths.sum()+len(lengths): - raise Exception("send_strings size mismatch {0} vs {1}".format( len(chars) , lengths.sum()+len(lengths) )) + lengths = numpy.array([len(s) for s in array], dtype="i") + + chars = (chr(0).join(array) + chr(0)).encode("utf-8") 
+ chars = numpy.frombuffer(chars, dtype="uint8") + + if len(chars) != lengths.sum() + len(lengths): + raise Exception( + "send_strings size mismatch {0} vs {1}".format( + len(chars), lengths.sum() + len(lengths) + ) + ) self.mpi_send(comm, [lengths, MPI.INT]) self.mpi_send(comm, [chars, MPI.CHARACTER]) - + def send_booleans(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='b') + sendbuffer = numpy.array(array, dtype="b") self.mpi_send(comm, [sendbuffer, MPI.C_BOOL or MPI.BYTE]) def set_error(self, message): self.strings = [message] self.error = True - + def mpi_nonblocking_receive(self, comm, array): raise NotImplementedError() - + def mpi_receive(self, comm, array): raise NotImplementedError() - + def mpi_send(self, comm, array): raise NotImplementedError() - - + + class ServerSideMPIMessage(MPIMessage): - + def mpi_receive(self, comm, array): request = comm.Irecv(array, source=0, tag=999) request.Wait() - + def mpi_send(self, comm, array): comm.Bcast(array, root=MPI.ROOT) - + def send_header(self, comm, array): requests = [] for rank in range(comm.Get_remote_size()): request = comm.Isend(array, dest=rank, tag=989) requests.append(request) MPI.Request.Waitall(requests) - - + def mpi_nonblocking_receive(self, comm, array): return comm.Irecv(array, source=0, tag=999) def receive_header(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") request = self.mpi_nonblocking_receive(comm, [header, MPI.INT]) if self.polling_interval > 0: is_finished = request.Test() while not is_finished: - time.sleep(self.polling_interval / 1000000.) 
+ time.sleep(self.polling_interval / 1000000.0) is_finished = request.Test() request.Wait() else: request.Wait() return header - - + class ClientSideMPIMessage(MPIMessage): - + def mpi_receive(self, comm, array): comm.Bcast(array, root=0) - + def mpi_send(self, comm, array): comm.Send(array, dest=0, tag=999) @@ -367,22 +374,24 @@ def mpi_nonblocking_receive(self, comm, array): return comm.Irecv(array, source=0, tag=999) def receive_header(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") request = comm.Irecv([header, MPI.INT], source=0, tag=989) if self.polling_interval > 0: is_finished = request.Test() while not is_finished: - time.sleep(self.polling_interval / 1000000.) + time.sleep(self.polling_interval / 1000000.0) is_finished = request.Test() request.Wait() else: request.Wait() return header + MAPPING = {} + def pack_array(array, length, dtype): - if dtype == 'string': + if dtype == "string": if length == 1 and len(array) > 0 and isinstance(array[0], str): return array result = [] @@ -402,171 +411,229 @@ def pack_array(array, length, dtype): result = MAPPING.dtype if len(result) != total_length: result = numpy.empty(length * len(array), dtype=dtype) - else: + else: result = numpy.empty(length * len(array), dtype=dtype) - + for i in range(len(array)): offset = i * length - result[offset:offset + length] = array[i] + result[offset : offset + length] = array[i] return result - + def unpack_array(array, length, dtype=None): result = [] total = len(array) // length for i in range(total): offset = i * length - result.append(array[offset:offset + length]) + result.append(array[offset : offset + length]) return result + class AbstractMessageChannel(OptionalAttributes): """ Abstract base class of all message channel. - + A message channel is used to send and retrieve messages from a remote party. A message channel can also setup the remote party. For example starting an instance of an application using MPI calls. 
- + The messages are encoded as arguments to the send and retrieve methods. Each message has an id and and optional list of doubles, integers, floats and/or strings. - + """ - + def __init__(self, **options): OptionalAttributes.__init__(self, **options) - + @classmethod - def GDB(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e', 'gdb'] - + def GDB( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-hold", "-display", os.environ["DISPLAY"], "-e", "gdb"] + if immediate_run: - arguments.extend([ '-ex', 'run']) - - arguments.extend(['--args']) - + arguments.extend(["-ex", "run"]) + + arguments.extend(["--args"]) + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - - command = 'xterm' + + command = "xterm" return command, arguments @classmethod - def LLDB(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e', 'lldb', '--'] + def LLDB( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-hold", "-display", os.environ["DISPLAY"], "-e", "lldb", "--"] if not interpreter_executable is None: arguments.append(interpreter_executable) arguments.append(full_name_of_the_worker) - command = 'xterm' + command = "xterm" return command, arguments @classmethod - def DDD(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - if os.name == 'nt': - arguments = [full_name_of_the_worker, "--args",full_name_of_the_worker] + def DDD( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + if os.name == "nt": + arguments = [full_name_of_the_worker, "--args", full_name_of_the_worker] command = 
channel.adg_exe return command, arguments else: - arguments = ['-display', os.environ['DISPLAY'], '-e', 'ddd', '--args'] - + arguments = ["-display", os.environ["DISPLAY"], "-e", "ddd", "--args"] + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - - command = 'xterm' + + command = "xterm" return command, arguments - + @classmethod - def VALGRIND(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def VALGRIND( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): # arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e', 'valgrind', full_name_of_the_worker] arguments = [] - + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - command = 'valgrind' + command = "valgrind" return command, arguments - - + @classmethod - def XTERM(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e'] - + def XTERM( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-hold", "-display", os.environ["DISPLAY"], "-e"] + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - - command = 'xterm' + + command = "xterm" return command, arguments - @classmethod - def REDIRECT(cls, full_name_of_the_worker, stdoutname, stderrname, command=None, - interpreter_executable=None, run_command_redirected_file=None ): - - fname = run_command_redirected_file or run_command_redirected.__file__ - arguments = [fname , stdoutname, stderrname] - + def REDIRECT( + cls, + full_name_of_the_worker, + stdoutname, + stderrname, + command=None, + interpreter_executable=None, + run_command_redirected_file=None, + ): + + fname = 
run_command_redirected_file or run_command_redirected.__file__ + arguments = [fname, stdoutname, stderrname] + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - - if command is None : + + if command is None: command = sys.executable - + return command, arguments - + @classmethod - def GDBR(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def GDBR( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): "remote gdb, can run without xterm" - - arguments = ['localhost:{0}'.format(channel.debugger_port)] - + + arguments = ["localhost:{0}".format(channel.debugger_port)] + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - + command = channel.gdbserver_exe return command, arguments - + @classmethod - def NODEBUGGER(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def NODEBUGGER( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): if not interpreter_executable is None: return interpreter_executable, [full_name_of_the_worker] else: return full_name_of_the_worker, [] - - + @classmethod - def STRACE(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-ostrace-out', '-ff'] + def STRACE( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-ostrace-out", "-ff"] if not interpreter_executable is None: arguments.append(interpreter_executable) arguments.append(full_name_of_the_worker) - command = 'strace' + command = "strace" return command, arguments - + @classmethod - def CUSTOM(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def CUSTOM( + cls, + full_name_of_the_worker, + channel, + 
interpreter_executable=None, + immediate_run=True, + ): arguments = list(shlex.split(channel.custom_args)) if not interpreter_executable is None: arguments.append(interpreter_executable) arguments.append(full_name_of_the_worker) command = channel.custom_exe return command, arguments - - + @classmethod def is_multithreading_supported(cls): return True @@ -575,111 +642,132 @@ def is_multithreading_supported(cls): def initialize_mpi(self): """Is MPI initialized in the code or not. Defaults to True if MPI is available""" return config.mpi.is_enabled - - @option(type='string', sections=("channel",)) + + @option(type="string", sections=("channel",)) def worker_code_suffix(self): - return '' - - @option(type='string', sections=("channel",)) + return "" + + @option(type="string", sections=("channel",)) def worker_code_prefix(self): - return '' - - @option(type='string', sections=("channel",)) + return "" + + @option(type="string", sections=("channel",)) def worker_code_directory(self): - return '' + return "" @option(type="boolean", sections=("channel",)) def can_redirect_output(self): return True - + @option(sections=("channel",)) def python_exe_for_redirection(self): return None - - + @option(type="int", sections=("channel",)) def debugger_port(self): return 4343 - + @option(type="string", sections=("channel",)) def gdbserver_exe(self): - return 'gdbserver' - + return "gdbserver" + @option(type="string", sections=("channel",)) def adg_exe(self): - return 'adg.exe' - + return "adg.exe" + @option(type="string", sections=("channel",)) def custom_exe(self): - return 'mintty.exe' - + return "mintty.exe" + @option(type="string", sections=("channel",)) def custom_args(self): - return '--hold -e gdb --args' + return "--hold -e gdb --args" - @option(type='boolean', sections=("channel",)) + @option(type="boolean", sections=("channel",)) def debugger_immediate_run(self): return True - - @option(type='boolean', sections=("channel",)) + + @option(type="boolean", 
sections=("channel",)) def must_check_if_worker_is_up_to_date(self): return True - @option(type='boolean', sections=("channel",)) + @option(type="boolean", sections=("channel",)) def check_worker_location(self): return True - + @option(type="int", sections=("channel",)) def number_of_workers(self): return 1 - + def get_amuse_root_directory(self): return self.amuse_root_dir - - @option(type="string", sections=('data',)) - def amuse_root_dir(self): # needed for location of data, so same as in support.__init__ + + @option(type="string", sections=("data",)) + def amuse_root_dir( + self, + ): # needed for location of data, so same as in support.__init__ return get_amuse_root_dir() - + def check_if_worker_is_up_to_date(self, object): if not self.must_check_if_worker_is_up_to_date: return - + name_of_the_compiled_file = self.full_name_of_the_worker modificationtime_of_worker = os.stat(name_of_the_compiled_file).st_mtime my_class = type(object) for x in dir(my_class): - if x.startswith('__'): + if x.startswith("__"): continue value = getattr(my_class, x) - if hasattr(value, 'crc32'): - is_up_to_date = value.is_compiled_file_up_to_date(modificationtime_of_worker) + if hasattr(value, "crc32"): + is_up_to_date = value.is_compiled_file_up_to_date( + modificationtime_of_worker + ) if not is_up_to_date: - raise exceptions.CodeException("""The worker code of the '{0}' interface class is not up to date. + raise exceptions.CodeException( + """The worker code of the '{0}' interface class is not up to date. Please do a 'make clean; make' in the root directory. 
-""".format(type(object).__name__)) +""".format( + type(object).__name__ + ) + ) def get_full_name_of_the_worker(self, type): if os.path.isabs(self.name_of_the_worker): - full_name_of_the_worker=self.name_of_the_worker - + full_name_of_the_worker = self.name_of_the_worker + if not self.check_worker_location: return full_name_of_the_worker - + if not os.path.exists(full_name_of_the_worker): - raise exceptions.CodeException("The worker path has been specified, but it is not found: \n{0}".format(full_name_of_the_worker)) + raise exceptions.CodeException( + "The worker path has been specified, but it is not found: \n{0}".format( + full_name_of_the_worker + ) + ) if not os.access(full_name_of_the_worker, os.X_OK): - raise exceptions.CodeException("The worker application exists, but it is not executable.\n{0}".format(full_name_of_the_worker)) - + raise exceptions.CodeException( + "The worker application exists, but it is not executable.\n{0}".format( + full_name_of_the_worker + ) + ) + return full_name_of_the_worker - - exe_name = self.worker_code_prefix + self.name_of_the_worker + self.worker_code_suffix + + exe_name = ( + self.worker_code_prefix + self.name_of_the_worker + self.worker_code_suffix + ) if not self.check_worker_location: if len(self.worker_code_directory) > 0: - full_name_of_the_worker = os.path.join(self.worker_code_directory, exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.join( + self.worker_code_directory, exe_name + ) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) return full_name_of_the_worker else: raise Exception("Must provide a worker_code_directory") @@ -687,22 +775,32 @@ def get_full_name_of_the_worker(self, type): tried_workers = [] directory = os.path.dirname(inspect.getfile(type)) - full_name_of_the_worker = os.path.join(directory, '..','..','_workers', exe_name) - full_name_of_the_worker = 
os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.join( + directory, "..", "..", "_workers", exe_name + ) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) - + if len(self.worker_code_directory) > 0: full_name_of_the_worker = os.path.join(self.worker_code_directory, exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) - + directory_of_this_module = os.path.dirname(os.path.dirname(__file__)) - full_name_of_the_worker = os.path.join(directory_of_this_module, '_workers', exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.join( + directory_of_this_module, "_workers", exe_name + ) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) @@ -711,40 +809,51 @@ def get_full_name_of_the_worker(self, type): while not current_type.__bases__[0] is object: directory_of_this_module = os.path.dirname(inspect.getfile(current_type)) full_name_of_the_worker = os.path.join(directory_of_this_module, exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) current_type = current_type.__bases__[0] - raise exceptions.CodeException("The worker 
application does not exist, it should be at: \n{0}".format('\n'.join(tried_workers))) - - def send_message(self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units = None): + raise exceptions.CodeException( + "The worker application does not exist, it should be at: \n{0}".format( + "\n".join(tried_workers) + ) + ) + + def send_message( + self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units=None + ): pass - - def recv_message(self, call_id=0, function_id=-1, handle_as_array=False, has_units = False): + + def recv_message( + self, call_id=0, function_id=-1, handle_as_array=False, has_units=False + ): pass - - def nonblocking_recv_message(self, call_id=0, function_id=-1, handle_as_array=False): + + def nonblocking_recv_message( + self, call_id=0, function_id=-1, handle_as_array=False + ): pass - + def start(self): pass - + def stop(self): pass def is_active(self): return True - + @classmethod def is_root(self): return True - + def is_polling_supported(self): return False - - + def determine_length_from_data(self, dtype_to_arguments): def get_length(type_and_values): argument_type, argument_values = type_and_values @@ -757,71 +866,81 @@ def get_length(type_and_values): except: result = max(result, 1) return result - - - + lengths = [get_length(x) for x in dtype_to_arguments.items()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) - def split_message(self, call_id, function_id, call_count, dtype_to_arguments, encoded_units = ()): - - if call_count<=1: + def split_message( + self, call_id, function_id, call_count, dtype_to_arguments, encoded_units=() + ): + + if call_count <= 1: raise Exception("split message called with call_count<=1") - + dtype_to_result = {} - - ndone=0 - while ndone>> is_mpd_running() True - - + + """ if not MpiChannel.is_supported(): return True - + MpiChannel.ensure_mpi_initialized() - + name_of_the_vendor, version = MPI.get_vendor() - if name_of_the_vendor == 'MPICH2': + if name_of_the_vendor == 
"MPICH2": must_check_mpd = True - if 'AMUSE_MPD_CHECK' in os.environ: - must_check_mpd = os.environ['AMUSE_MPD_CHECK'] == '1' - if 'PMI_PORT' in os.environ: + if "AMUSE_MPD_CHECK" in os.environ: + must_check_mpd = os.environ["AMUSE_MPD_CHECK"] == "1" + if "PMI_PORT" in os.environ: must_check_mpd = False - if 'PMI_RANK' in os.environ: + if "PMI_RANK" in os.environ: must_check_mpd = False - if 'HYDRA_CONTROL_FD' in os.environ: + if "HYDRA_CONTROL_FD" in os.environ: must_check_mpd = False - + if not must_check_mpd: return True try: - process = Popen(['mpdtrace'], stdout=PIPE, stderr=PIPE) + process = Popen(["mpdtrace"], stdout=PIPE, stderr=PIPE) (output_string, error_string) = process.communicate() return not (process.returncode == 255) except OSError as ex: @@ -882,13 +1001,14 @@ def is_mpd_running(): class MpiChannel(AbstractMessageChannel): """ Message channel based on MPI calls to send and recv the messages - + :argument name_of_the_worker: Name of the application to start :argument number_of_workers: Number of parallel processes :argument legacy_interface_type: Type of the legacy interface :argument debug_with_gdb: If True opens an xterm with a gdb to debug the remote process :argument hostname: Name of the node to run the application on """ + _mpi_is_broken_after_possible_code_crash = False _intercomms_to_disconnect = [] _is_registered = False @@ -896,68 +1016,81 @@ class MpiChannel(AbstractMessageChannel): _scheduler_index = 0 _scheduler_initialized = False - - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, **options): + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + **options, + ): AbstractMessageChannel.__init__(self, **options) - + self.inuse_semaphore = threading.Semaphore() # logging.basicConfig(level=logging.WARN) # logger.setLevel(logging.DEBUG) # logging.getLogger("code").setLevel(logging.DEBUG) - + self.ensure_mpi_initialized() - + 
self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + if not legacy_interface_type is None: - self.full_name_of_the_worker = self.get_full_name_of_the_worker(legacy_interface_type) + self.full_name_of_the_worker = self.get_full_name_of_the_worker( + legacy_interface_type + ) else: self.full_name_of_the_worker = self.name_of_the_worker - + if self.check_mpi: if not is_mpd_running(): - raise exceptions.CodeException("The mpd daemon is not running, please make sure it is started before starting this code") - + raise exceptions.CodeException( + "The mpd daemon is not running, please make sure it is started before starting this code" + ) + if self._mpi_is_broken_after_possible_code_crash: - raise exceptions.CodeException("Another code has crashed, cannot spawn a new code, please stop the script and retry") + raise exceptions.CodeException( + "Another code has crashed, cannot spawn a new code, please stop the script and retry" + ) if not self.hostname is None: self.info = MPI.Info.Create() - self.info['host'] = self.hostname + self.info["host"] = self.hostname else: if self.job_scheduler: - self.info = self.get_info_from_job_scheduler(self.job_scheduler, self.number_of_workers) + self.info = self.get_info_from_job_scheduler( + self.job_scheduler, self.number_of_workers + ) else: self.info = MPI.Info.Create() - - for key,value in self.mpi_info_options.items(): - self.info[key]=value - + + for key, value in self.mpi_info_options.items(): + self.info[key] = value + self.cached = None self.intercomm = None self._is_inuse = False self._communicated_splitted_message = False logger.debug("MPI channel created with info items: %s", str(self.info.items())) - @classmethod def ensure_mpi_initialized(cls): global MPI - + if MPI is None: import mpi4py.MPI + MPI = mpi4py.MPI cls.register_finalize_code() @classmethod def is_threaded(cls): - #We want this for backwards compatibility with mpi4py versions < 2.0.0 - #currently unused after 
Init/Init_threaded was removed from - #this module. + # We want this for backwards compatibility with mpi4py versions < 2.0.0 + # currently unused after Init/Init_threaded was removed from + # this module. from mpi4py import rc + try: return rc.threaded except AttributeError: @@ -968,7 +1101,7 @@ def register_finalize_code(cls): if not cls._is_registered: atexit.register(cls.finialize_mpi_atexit) cls._is_registered = True - + @classmethod def finialize_mpi_atexit(cls): if not MPI.Is_initialized(): @@ -978,26 +1111,26 @@ def finialize_mpi_atexit(cls): try: for x in cls._intercomms_to_disconnect: x.Disconnect() - + except MPI.Exception as ex: return - + @classmethod def is_multithreading_supported(cls): return MPI.Query_thread() == MPI.THREAD_MULTIPLE - + @option(type="boolean", sections=("channel",)) def check_mpi(self): return True - + @option(type="boolean", sections=("channel",)) def debug_with_gdb(self): return False - + @option(sections=("channel",)) def hostname(self): return None - + @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): """Name of the debugger to use when starting the code""" @@ -1006,78 +1139,83 @@ def debugger(self): @option(type="dict", sections=("channel",)) def mpi_info_options(self): return dict() - + @option(type="int", sections=("channel",)) def max_message_length(self): """ For calls to functions that can handle arrays, MPI messages may get too long for large N. The MPI channel will split long messages into blocks of size max_message_length. 
- """ + """ return 1000000 - @late def redirect_stdout_file(self): return "/dev/null" - + @late def redirect_stderr_file(self): return "/dev/null" - + @late def debugger_method(self): return self.DEBUGGERS[self.debugger] - + @classmethod def is_supported(cls): - if hasattr(config, 'mpi') and hasattr(config.mpi, 'is_enabled'): + if hasattr(config, "mpi") and hasattr(config.mpi, "is_enabled"): if not config.mpi.is_enabled: return False try: from mpi4py import MPI + return True except ImportError: return False - @option(type="boolean", sections=("channel",)) def can_redirect_output(self): name_of_the_vendor, version = MPI.get_vendor() - if name_of_the_vendor == 'MPICH2': - if 'MPISPAWN_ARGV_0' in os.environ: + if name_of_the_vendor == "MPICH2": + if "MPISPAWN_ARGV_0" in os.environ: return False return True - - + @option(type="boolean", sections=("channel",)) def must_disconnect_on_stop(self): name_of_the_vendor, version = MPI.get_vendor() - if name_of_the_vendor == 'MPICH2': - if 'MPISPAWN_ARGV_0' in os.environ: + if name_of_the_vendor == "MPICH2": + if "MPISPAWN_ARGV_0" in os.environ: return False return True - + @option(type="int", sections=("channel",)) def polling_interval_in_milliseconds(self): return 0 - + @classmethod def is_root(cls): cls.ensure_mpi_initialized() return MPI.COMM_WORLD.rank == 0 - + def start(self): logger.debug("starting mpi worker process") logger.debug("mpi_enabled: %s", str(self.initialize_mpi)) - + if not self.debugger_method is None: - command, arguments = self.debugger_method(self.full_name_of_the_worker, self, - interpreter_executable=self.interpreter_executable, immediate_run=self.debugger_immediate_run) + command, arguments = self.debugger_method( + self.full_name_of_the_worker, + self, + interpreter_executable=self.interpreter_executable, + immediate_run=self.debugger_immediate_run, + ) else: - if not self.can_redirect_output or (self.redirect_stdout_file == 'none' and self.redirect_stderr_file == 'none'): - + if not 
self.can_redirect_output or ( + self.redirect_stdout_file == "none" + and self.redirect_stderr_file == "none" + ): + if self.interpreter_executable is None: command = self.full_name_of_the_worker arguments = None @@ -1085,16 +1223,28 @@ def start(self): command = self.interpreter_executable arguments = [self.full_name_of_the_worker] else: - command, arguments = self.REDIRECT(self.full_name_of_the_worker, self.redirect_stdout_file, self.redirect_stderr_file, command=self.python_exe_for_redirection, interpreter_executable=self.interpreter_executable) - - logger.debug("spawning %d mpi processes with command `%s`, arguments `%s` and environment '%s'", self.number_of_workers, command, arguments, os.environ) + command, arguments = self.REDIRECT( + self.full_name_of_the_worker, + self.redirect_stdout_file, + self.redirect_stderr_file, + command=self.python_exe_for_redirection, + interpreter_executable=self.interpreter_executable, + ) + + logger.debug( + "spawning %d mpi processes with command `%s`, arguments `%s` and environment '%s'", + self.number_of_workers, + command, + arguments, + os.environ, + ) - self.intercomm = MPI.COMM_SELF.Spawn(command, arguments, self.number_of_workers, info=self.info) + self.intercomm = MPI.COMM_SELF.Spawn( + command, arguments, self.number_of_workers, info=self.info + ) logger.debug("worker spawn done") - - - + def stop(self): if not self.intercomm is None: try: @@ -1105,9 +1255,9 @@ def stop(self): except MPI.Exception as ex: if ex.error_class == MPI.ERR_OTHER: type(self)._mpi_is_broken_after_possible_code_crash = True - + self.intercomm = None - + def determine_length_from_datax(self, dtype_to_arguments): def get_length(x): if x: @@ -1117,53 +1267,60 @@ def get_length(x): except: return 1 return 1 - - - + lengths = [get_length(x) for x in dtype_to_arguments.values()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) - - - def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = ()): - + def 
send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=() + ): + if self.intercomm is None: - raise exceptions.CodeException("You've tried to send a message to a code that is not running") - + raise exceptions.CodeException( + "You've tried to send a message to a code that is not running" + ) + call_count = self.determine_length_from_data(dtype_to_arguments) - + if call_count > self.max_message_length: - self.split_message(call_id, function_id, call_count, dtype_to_arguments, encoded_units) + self.split_message( + call_id, function_id, call_count, dtype_to_arguments, encoded_units + ) else: if self.is_inuse(): - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) self.inuse_semaphore.acquire() try: if self._is_inuse: - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) self._is_inuse = True finally: self.inuse_semaphore.release() message = ServerSideMPIMessage( - call_id, function_id, - call_count, dtype_to_arguments, - encoded_units = encoded_units + call_id, + function_id, + call_count, + dtype_to_arguments, + encoded_units=encoded_units, ) message.send(self.intercomm) + def recv_message(self, call_id, function_id, handle_as_array, has_units=False): - def recv_message(self, call_id, function_id, handle_as_array, has_units = False): - if self._communicated_splitted_message: x = self._merged_results_splitted_message self._communicated_splitted_message = False del self._merged_results_splitted_message return x - + message = ServerSideMPIMessage( polling_interval=self.polling_interval_in_milliseconds * 1000 ) @@ 
-1176,74 +1333,103 @@ def recv_message(self, call_id, function_id, handle_as_array, has_units = False) self.inuse_semaphore.acquire() try: if not self._is_inuse: - raise exceptions.CodeException("You've tried to recv a message to a code that is not handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to recv a message to a code that is not handling a message, this is not correct" + ) self._is_inuse = False finally: self.inuse_semaphore.release() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] if len(message.strings) > 0 else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." + self.stop() + error_message += " - code probably died, sorry." raise exceptions.CodeException("Error in code: " + error_message) if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) - - def nonblocking_recv_message(self, call_id, function_id, handle_as_array, has_units = False): + def nonblocking_recv_message( + self, call_id, function_id, handle_as_array, has_units=False + ): request = ServerSideMPIMessage().nonblocking_receive(self.intercomm) + def handle_result(function): 
self._is_inuse = False - + message = function() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] + if len(message.strings) > 0 + else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." - raise exceptions.CodeException("Error in (asynchronous) communication with worker: " + error_message) - + self.stop() + error_message += " - code probably died, sorry." + raise exceptions.CodeException( + "Error in (asynchronous) communication with worker: " + + error_message + ) + if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) - + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) + if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) request.add_result_handler(handle_result) - + return request - + def is_active(self): return self.intercomm is not None - + def is_inuse(self): return self._is_inuse - + def is_polling_supported(self): return True - + def __getstate__(self): - return {'state':'empty'} - + return {"state": "empty"} + def __setstate__(self, state): self.info = MPI.INFO_NULL self.cached = None @@ -1257,19 +1443,23 @@ def job_scheduler(self): """Name of the job scheduler to use when starting the code, if given will use job scheduler to find list of hostnames for spawning""" return "" - def 
get_info_from_job_scheduler(self, name, number_of_workers = 1): + def get_info_from_job_scheduler(self, name, number_of_workers=1): if name == "slurm": return self.get_info_from_slurm(number_of_workers) return MPI.INFO_NULL @classmethod def get_info_from_slurm(cls, number_of_workers): - has_slurm_env_variables = 'SLURM_NODELIST' in os.environ and 'SLURM_TASKS_PER_NODE' in os.environ + has_slurm_env_variables = ( + "SLURM_NODELIST" in os.environ and "SLURM_TASKS_PER_NODE" in os.environ + ) if not has_slurm_env_variables: return MPI.INFO_NULL if not cls._scheduler_initialized: - nodelist = slurm.parse_slurm_nodelist(os.environ['SLURM_NODELIST']) - tasks_per_node = slurm.parse_slurm_tasks_per_node(os.environ['SLURM_TASKS_PER_NODE']) + nodelist = slurm.parse_slurm_nodelist(os.environ["SLURM_NODELIST"]) + tasks_per_node = slurm.parse_slurm_tasks_per_node( + os.environ["SLURM_TASKS_PER_NODE"] + ) all_nodes = [] for node, tasks in zip(nodelist, tasks_per_node): for _ in range(tasks): @@ -1281,29 +1471,30 @@ def get_info_from_slurm(cls, number_of_workers): hostnames = [] count = 0 while count < number_of_workers: - hostnames.append(cls._scheduler_nodes[cls._scheduler_index]) - count += 1 - cls._scheduler_index += 1 - if cls._scheduler_index >= len(cls._scheduler_nodes): - cls._scheduler_index = 0 - host = ','.join(hostnames) - print("HOST:", host, cls._scheduler_index, os.environ['SLURM_TASKS_PER_NODE']) + hostnames.append(cls._scheduler_nodes[cls._scheduler_index]) + count += 1 + cls._scheduler_index += 1 + if cls._scheduler_index >= len(cls._scheduler_nodes): + cls._scheduler_index = 0 + host = ",".join(hostnames) + print("HOST:", host, cls._scheduler_index, os.environ["SLURM_TASKS_PER_NODE"]) info = MPI.Info.Create() - info['host'] = host # actually in mpich and openmpi, the host parameter is interpreted as a comma separated list of host names, + info["host"] = ( + host # actually in mpich and openmpi, the host parameter is interpreted as a comma separated list of host 
names, + ) return info - class MultiprocessingMPIChannel(AbstractMessageChannel): """ - Message channel based on JSON messages. - + Message channel based on JSON messages. + The remote party functions as a message forwarder. Each message is forwarded to a real application using MPI. This is message channel is a lot slower than the MPI message channel. But, it is useful during testing with the MPICH2 nemesis channel. As the tests will run as one - application on one node they will cause oversaturation + application on one node they will cause oversaturation of the processor(s) on the node. Each legacy code will call the MPI_FINALIZE call and this call will wait for the MPI_FINALIZE call of the main test process. During @@ -1312,52 +1503,63 @@ class MultiprocessingMPIChannel(AbstractMessageChannel): instead of the normal MPIChannel. Then, part of the test is performed in a separate application (at least as MPI sees it) and this part can be stopped after each - sub-test, thus removing unneeded applications. + sub-test, thus removing unneeded applications. 
""" - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, **options): + + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + **options, + ): AbstractMessageChannel.__init__(self, **options) - + self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + if not legacy_interface_type is None: - self.full_name_of_the_worker = self.get_full_name_of_the_worker(legacy_interface_type) + self.full_name_of_the_worker = self.get_full_name_of_the_worker( + legacy_interface_type + ) else: self.full_name_of_the_worker = self.name_of_the_worker - + self.process = None - + @option(type="boolean") def debug_with_gdb(self): return False - + @option def hostname(self): return None - + def start(self): - name_of_dir = "/tmp/amuse_" + os.getenv('USER') - self.name_of_the_socket, self.server_socket = self._createAServerUNIXSocket(name_of_dir) + name_of_dir = "/tmp/amuse_" + os.getenv("USER") + self.name_of_the_socket, self.server_socket = self._createAServerUNIXSocket( + name_of_dir + ) environment = os.environ.copy() - - if 'PYTHONPATH' in environment: - environment['PYTHONPATH'] = environment['PYTHONPATH'] + ':' + self._extra_path_item(__file__) + + if "PYTHONPATH" in environment: + environment["PYTHONPATH"] = ( + environment["PYTHONPATH"] + ":" + self._extra_path_item(__file__) + ) else: - environment['PYTHONPATH'] = self._extra_path_item(__file__) - - + environment["PYTHONPATH"] = self._extra_path_item(__file__) + all_options = {} for x in self.iter_options(): all_options[x.name] = getattr(self, x.name) - - + template = """from {3} import {4} o = {1!r} m = channel.MultiprocessingMPIChannel('{0}',**o) m.run_mpi_channel('{2}')""" modulename = type(self).__module__ - packagagename, thismodulename = modulename.rsplit('.', 1) - + packagagename, thismodulename = modulename.rsplit(".", 1) + code_string = template.format( 
self.full_name_of_the_worker, all_options, @@ -1367,19 +1569,25 @@ def start(self): ) self.process = Popen([sys.executable, "-c", code_string], env=environment) self.client_socket, undef = self.server_socket.accept() - + def is_active(self): return self.process is not None - + def stop(self): - self._send(self.client_socket, ('stop', (),)) - result = self._recv(self.client_socket) + self._send( + self.client_socket, + ( + "stop", + (), + ), + ) + result = self._recv(self.client_socket) self.process.wait() self.client_socket.close() self.server_socket.close() self._remove_socket(self.name_of_the_socket) self.process = None - + def run_mpi_channel(self, name_of_the_socket): channel = MpiChannel(self.full_name_of_the_worker, **self._local_options) channel.start() @@ -1389,39 +1597,55 @@ def run_mpi_channel(self, name_of_the_socket): while is_running: message, args = self._recv(socket) result = None - if message == 'stop': + if message == "stop": channel.stop() is_running = False - if message == 'send_message': + if message == "send_message": result = channel.send_message(*args) - if message == 'recv_message': + if message == "recv_message": result = channel.recv_message(*args) self._send(socket, result) finally: socket.close() - - def send_message(self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units = ()): - self._send(self.client_socket, ('send_message', (call_id, function_id, dtype_to_arguments),)) + + def send_message( + self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units=() + ): + self._send( + self.client_socket, + ( + "send_message", + (call_id, function_id, dtype_to_arguments), + ), + ) result = self._recv(self.client_socket) return result - def recv_message(self, call_id=0, function_id=-1, handle_as_array=False, has_units=False): - self._send(self.client_socket, ('recv_message', (call_id, function_id, handle_as_array),)) - result = self._recv(self.client_socket) + def recv_message( + self, call_id=0, function_id=-1, 
handle_as_array=False, has_units=False + ): + self._send( + self.client_socket, + ( + "recv_message", + (call_id, function_id, handle_as_array), + ), + ) + result = self._recv(self.client_socket) return result - + def _send(self, client_socket, message): message_string = pickle.dumps(message) header = struct.pack("i", len(message_string)) client_socket.sendall(header) client_socket.sendall(message_string) - + def _recv(self, client_socket): header = self._receive_all(client_socket, 4) length = struct.unpack("i", header) message_string = self._receive_all(client_socket, length[0]) return pickle.loads(message_string) - + def _receive_all(self, client_socket, number_of_bytes): block_size = 4096 bytes_left = number_of_bytes @@ -1433,18 +1657,17 @@ def _receive_all(self, client_socket, number_of_bytes): blocks.append(block) bytes_left -= len(block) return bytearray().join(blocks) - - + def _createAServerUNIXSocket(self, name_of_the_directory, name_of_the_socket=None): import uuid import socket - + if name_of_the_socket == None: name_of_the_socket = os.path.join(name_of_the_directory, str(uuid.uuid1())) - + if not os.path.exists(name_of_the_directory): os.makedirs(name_of_the_directory) - + server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self._remove_socket(name_of_the_socket) server_socket.bind(name_of_the_socket) @@ -1453,19 +1676,20 @@ def _createAServerUNIXSocket(self, name_of_the_directory, name_of_the_socket=Non def _createAClientUNIXSocket(self, name_of_the_socket): import socket + client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - # client_socket.settimeout(0)header + # client_socket.settimeout(0)header client_socket.connect(name_of_the_socket) return client_socket - + def _remove_socket(self, name_of_the_socket): try: os.remove(name_of_the_socket) except OSError: pass - + def _extra_path_item(self, path_of_the_module): - result = '' + result = "" for x in sys.path: if path_of_the_module.startswith(x): if len(x) > len(result): 
@@ -1483,58 +1707,60 @@ def check_mpi(self): class SocketMessage(AbstractMessage): - + def _receive_all(self, nbytes, thesocket): # logger.debug("receiving %d bytes", nbytes) - + result = [] - + while nbytes > 0: chunk = min(nbytes, 10240) data_bytes = thesocket.recv(chunk) - + if len(data_bytes) == 0: raise exceptions.CodeException("lost connection to code") - + result.append(data_bytes) nbytes -= len(data_bytes) # logger.debug("got %d bytes, result length = %d", len(data_bytes), len(result)) - + if len(result) > 0: return type(result[0])().join(result) else: return b"" - + def receive(self, socket): - + # logger.debug("receiving message") - + header_bytes = self._receive_all(44, socket) - + flags = numpy.frombuffer(header_bytes, dtype="b", count=4, offset=0) - + if flags[0] != self.big_endian: - raise exceptions.CodeException("endianness in message does not match native endianness") - + raise exceptions.CodeException( + "endianness in message does not match native endianness" + ) + if flags[1]: self.error = True else: self.error = False - + header = numpy.copy(numpy.frombuffer(header_bytes, dtype="i", offset=0)) - + # logger.debug("receiving message with flags %s and header %s", flags, header) # id of this call self.call_id = header[1] - + # function ID self.function_id = header[2] - + # number of calls in this message self.call_count = header[3] - + # number of X's in TOTAL number_of_ints = header[4] number_of_longs = header[5] @@ -1551,115 +1777,114 @@ def receive(self, socket): self.booleans = self.receive_booleans(socket, number_of_booleans) self.strings = self.receive_strings(socket, number_of_strings) self.encoded_units = self.receive_doubles(socket, number_of_units) - + # logger.debug("message received") - def receive_ints(self, socket, count): if count > 0: nbytes = count * 4 # size of int - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='int32')) - + + result = 
numpy.copy(numpy.frombuffer(data_bytes, dtype="int32")) + return result else: - return [] - + return [] + def receive_longs(self, socket, count): if count > 0: nbytes = count * 8 # size of long - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='int64')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="int64")) + return result else: return [] - - + def receive_floats(self, socket, count): if count > 0: nbytes = count * 4 # size of float - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='f4')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="f4")) + return result else: return [] - - + def receive_doubles(self, socket, count): if count > 0: nbytes = count * 8 # size of double - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='f8')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="f8")) + return result else: return [] - def receive_booleans(self, socket, count): if count > 0: nbytes = count * 1 # size of boolean/byte - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='b')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="b")) + return result else: return [] - - + def receive_strings(self, socket, count): if count > 0: lengths = self.receive_ints(socket, count) - + total = lengths.sum() + len(lengths) - + data_bytes = self._receive_all(total, socket) strings = [] begin = 0 for size in lengths: - strings.append(data_bytes[begin:begin + size].decode('utf-8')) + strings.append(data_bytes[begin : begin + size].decode("utf-8")) begin = begin + size + 1 return numpy.array(strings) else: return [] - + def nonblocking_receive(self, socket): return async_request.ASyncSocketRequest(self, socket) - - + def send(self, socket): - - flags = numpy.array([self.big_endian, self.error, 
len(self.encoded_units) > 0, False], dtype="b") - - header = numpy.array([ - self.call_id, - self.function_id, - self.call_count, - len(self.ints), - len(self.longs), - len(self.floats), - len(self.doubles), - len(self.booleans), - len(self.strings), - len(self.encoded_units), - ], dtype='i') - + + flags = numpy.array( + [self.big_endian, self.error, len(self.encoded_units) > 0, False], dtype="b" + ) + + header = numpy.array( + [ + self.call_id, + self.function_id, + self.call_count, + len(self.ints), + len(self.longs), + len(self.floats), + len(self.doubles), + len(self.booleans), + len(self.strings), + len(self.encoded_units), + ], + dtype="i", + ) + # logger.debug("sending message with flags %s and header %s", flags, header) - + socket.sendall(flags.tobytes()) socket.sendall(header.tobytes()) @@ -1671,127 +1896,148 @@ def send(self, socket): self.send_booleans(socket, self.booleans) self.send_strings(socket, self.strings) self.send_doubles(socket, self.encoded_units) - + # logger.debug("message send") def send_doubles(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='f8') + data_buffer = numpy.array(array, dtype="f8") socket.sendall(data_buffer.tobytes()) - + def send_ints(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='int32') + data_buffer = numpy.array(array, dtype="int32") socket.sendall(data_buffer.tobytes()) - + def send_floats(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='f4') + data_buffer = numpy.array(array, dtype="f4") socket.sendall(data_buffer.tobytes()) - + def send_strings(self, socket, array): if len(array) > 0: - - lengths = numpy.array( [len(s) for s in array] ,dtype='int32') - chars=(chr(0).join(array)+chr(0)).encode("utf-8") - - if len(chars) != lengths.sum()+len(lengths): - raise Exception("send_strings size mismatch {0} vs {1}".format( len(chars) , lengths.sum()+len(lengths) )) + + lengths = numpy.array([len(s) for s in array], 
dtype="int32") + chars = (chr(0).join(array) + chr(0)).encode("utf-8") + + if len(chars) != lengths.sum() + len(lengths): + raise Exception( + "send_strings size mismatch {0} vs {1}".format( + len(chars), lengths.sum() + len(lengths) + ) + ) self.send_ints(socket, lengths) socket.sendall(chars) - + def send_booleans(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='b') + data_buffer = numpy.array(array, dtype="b") socket.sendall(data_buffer.tobytes()) def send_longs(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='int64') + data_buffer = numpy.array(array, dtype="int64") socket.sendall(data_buffer.tobytes()) class SocketChannel(AbstractMessageChannel): - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, - remote_env=None, **options): + + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + remote_env=None, + **options, + ): AbstractMessageChannel.__init__(self, **options) - - #logging.getLogger().setLevel(logging.DEBUG) - + + # logging.getLogger().setLevel(logging.DEBUG) + logger.debug("initializing SocketChannel with options %s", options) - + # self.name_of_the_worker = name_of_the_worker + "_sockets" self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + if self.hostname == None: - self.hostname="localhost" + self.hostname = "localhost" - if self.hostname not in ['localhost',socket.gethostname()]: - self.remote=True - self.must_check_if_worker_is_up_to_date=False + if self.hostname not in ["localhost", socket.gethostname()]: + self.remote = True + self.must_check_if_worker_is_up_to_date = False else: - self.remote=False - + self.remote = False + self.id = 0 - + if not legacy_interface_type is None: - self.full_name_of_the_worker = self.get_full_name_of_the_worker(legacy_interface_type) + self.full_name_of_the_worker = self.get_full_name_of_the_worker( 
+ legacy_interface_type + ) else: self.full_name_of_the_worker = self.name_of_the_worker - + logger.debug("full name of worker is %s", self.full_name_of_the_worker) - + self._is_inuse = False self._communicated_splitted_message = False self.socket = None - - self.remote_env=remote_env + + self.remote_env = remote_env @option(sections=("channel",)) def mpiexec(self): """mpiexec with arguments""" if len(config.mpi.mpiexec): return config.mpi.mpiexec - return '' + return "" @option(sections=("channel",)) def mpiexec_number_of_workers_flag(self): """flag to use, so that the number of workers are defined""" - return '-n' + return "-n" @late def debugger_method(self): return self.DEBUGGERS[self.debugger] - + def accept_worker_connection(self, server_socket, process): - #wait for the worker to connect. check if the process is still running once in a while + # wait for the worker to connect. check if the process is still running once in a while for i in range(0, 60): - #logger.debug("accepting connection") + # logger.debug("accepting connection") try: server_socket.settimeout(1.0) return server_socket.accept() except socket.timeout: - #update and read returncode + # update and read returncode if process.poll() is not None: - raise exceptions.CodeException('could not connect to worker, worker process terminated') - #logger.error("worker not connecting, waiting...") - - raise exceptions.CodeException('worker still not started after 60 seconds') + raise exceptions.CodeException( + "could not connect to worker, worker process terminated" + ) + # logger.error("worker not connecting, waiting...") + + raise exceptions.CodeException("worker still not started after 60 seconds") - def generate_command_and_arguments(self,server_address,port): + def generate_command_and_arguments(self, server_address, port): arguments = [] - + if not self.debugger_method is None: - command, arguments = self.debugger_method(self.full_name_of_the_worker, self, 
interpreter_executable=self.interpreter_executable) + command, arguments = self.debugger_method( + self.full_name_of_the_worker, + self, + interpreter_executable=self.interpreter_executable, + ) else: - if self.redirect_stdout_file == 'none' and self.redirect_stderr_file == 'none': - + if ( + self.redirect_stdout_file == "none" + and self.redirect_stderr_file == "none" + ): + if self.interpreter_executable is None: command = self.full_name_of_the_worker arguments = [] @@ -1799,9 +2045,15 @@ def generate_command_and_arguments(self,server_address,port): command = self.interpreter_executable arguments = [self.full_name_of_the_worker] else: - command, arguments = self.REDIRECT(self.full_name_of_the_worker, self.redirect_stdout_file, self.redirect_stderr_file, command=self.python_exe_for_redirection, interpreter_executable=self.interpreter_executable) - - #start arguments with command + command, arguments = self.REDIRECT( + self.full_name_of_the_worker, + self.redirect_stdout_file, + self.redirect_stderr_file, + command=self.python_exe_for_redirection, + interpreter_executable=self.interpreter_executable, + ) + + # start arguments with command arguments.insert(0, command) if self.initialize_mpi and len(self.mpiexec) > 0: @@ -1812,73 +2064,92 @@ def generate_command_and_arguments(self,server_address,port): arguments[:0] = mpiexec command = mpiexec[0] - #append with port and hostname where the worker should connect + # append with port and hostname where the worker should connect arguments.append(port) - #hostname of this machine + # hostname of this machine arguments.append(server_address) - - #initialize MPI inside worker executable - arguments.append('true') + + # initialize MPI inside worker executable + arguments.append("true") else: - #append arguments with port and socket where the worker should connect + # append arguments with port and socket where the worker should connect arguments.append(port) - #local machine + # local machine arguments.append(server_address) 
- - #do not initialize MPI inside worker executable - arguments.append('false') - return command,arguments + # do not initialize MPI inside worker executable + arguments.append("false") + + return command, arguments def remote_env_string(self, hostname): if self.remote_env is None: - if hostname in self.remote_envs.keys(): - return "source "+self.remote_envs[hostname]+"\n" - else: - return "" + if hostname in self.remote_envs.keys(): + return "source " + self.remote_envs[hostname] + "\n" + else: + return "" else: - return "source "+self.remote_env +"\n" + return "source " + self.remote_env + "\n" + + def generate_remote_command_and_arguments(self, hostname, server_address, port): - def generate_remote_command_and_arguments(self,hostname, server_address,port): - # get remote config - args=["ssh","-T", hostname] + args = ["ssh", "-T", hostname] - command=self.remote_env_string(self.hostname)+ \ - "amusifier --get-amuse-config" +"\n" - - proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") - out,err=proc.communicate(command.encode()) + command = ( + self.remote_env_string(self.hostname) + + "amusifier --get-amuse-config" + + "\n" + ) + + proc = Popen(args, stdout=PIPE, stdin=PIPE, executable="ssh") + out, err = proc.communicate(command.encode()) try: - remote_config=parse_configmk_lines(out.decode().split("\n"),"remote config at "+self.hostname ) + remote_config = parse_configmk_lines( + out.decode().split("\n"), "remote config at " + self.hostname + ) except: - raise Exception(f"failed getting remote config from {self.hostname} - please check remote_env argument ({self.remote_env})") + raise Exception( + f"failed getting remote config from {self.hostname} - please check remote_env argument ({self.remote_env})" + ) # get remote amuse package dir - command=self.remote_env_string(self.hostname)+ \ - "amusifier --get-amuse-package-dir" +"\n" - - proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") - out,err=proc.communicate(command.encode()) - - 
remote_package_dir=out.decode().strip(" \n\t") - local_package_dir=get_amuse_package_dir() - - mpiexec=remote_config["MPIEXEC"] - initialize_mpi=remote_config["MPI_ENABLED"] == 'yes' - run_command_redirected_file=run_command_redirected.__file__.replace(local_package_dir,remote_package_dir) - interpreter_executable=None if self.interpreter_executable==None else remote_config["PYTHON"] + command = ( + self.remote_env_string(self.hostname) + + "amusifier --get-amuse-package-dir" + + "\n" + ) + + proc = Popen(args, stdout=PIPE, stdin=PIPE, executable="ssh") + out, err = proc.communicate(command.encode()) + + remote_package_dir = out.decode().strip(" \n\t") + local_package_dir = get_amuse_package_dir() + + mpiexec = remote_config["MPIEXEC"] + initialize_mpi = remote_config["MPI_ENABLED"] == "yes" + run_command_redirected_file = run_command_redirected.__file__.replace( + local_package_dir, remote_package_dir + ) + interpreter_executable = ( + None if self.interpreter_executable == None else remote_config["PYTHON"] + ) # dynamic python workers? 
(should be send over) - full_name_of_the_worker=self.full_name_of_the_worker.replace(local_package_dir,remote_package_dir) - python_exe_for_redirection=remote_config["PYTHON"] + full_name_of_the_worker = self.full_name_of_the_worker.replace( + local_package_dir, remote_package_dir + ) + python_exe_for_redirection = remote_config["PYTHON"] if not self.debugger_method is None: raise Exception("remote socket channel debugging not yet supported") - #command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) + # command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) else: - if self.redirect_stdout_file == 'none' and self.redirect_stderr_file == 'none': - + if ( + self.redirect_stdout_file == "none" + and self.redirect_stderr_file == "none" + ): + if interpreter_executable is None: command = full_name_of_the_worker arguments = [] @@ -1886,12 +2157,16 @@ def generate_remote_command_and_arguments(self,hostname, server_address,port): command = interpreter_executable arguments = [full_name_of_the_worker] else: - command, arguments = self.REDIRECT(full_name_of_the_worker, self.redirect_stdout_file, - self.redirect_stderr_file, command=python_exe_for_redirection, - interpreter_executable=interpreter_executable, - run_command_redirected_file=run_command_redirected_file) - - #start arguments with command + command, arguments = self.REDIRECT( + full_name_of_the_worker, + self.redirect_stdout_file, + self.redirect_stderr_file, + command=python_exe_for_redirection, + interpreter_executable=interpreter_executable, + run_command_redirected_file=run_command_redirected_file, + ) + + # start arguments with command arguments.insert(0, command) if initialize_mpi and len(mpiexec) > 0: @@ -1902,75 +2177,108 @@ def generate_remote_command_and_arguments(self,hostname, server_address,port): arguments[:0] = mpiexec command = mpiexec[0] - #append with 
port and hostname where the worker should connect + # append with port and hostname where the worker should connect arguments.append(port) - #hostname of this machine + # hostname of this machine arguments.append(server_address) - - #initialize MPI inside worker executable - arguments.append('true') + + # initialize MPI inside worker executable + arguments.append("true") else: - #append arguments with port and socket where the worker should connect + # append arguments with port and socket where the worker should connect arguments.append(port) - #local machine + # local machine arguments.append(server_address) - - #do not initialize MPI inside worker executable - arguments.append('false') - return command,arguments + # do not initialize MPI inside worker executable + arguments.append("false") + + return command, arguments def start(self): - + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - server_address=self.get_host_ip(self.hostname) - - server_socket.bind((server_address , 0)) + + server_address = self.get_host_ip(self.hostname) + + server_socket.bind((server_address, 0)) server_socket.settimeout(1.0) server_socket.listen(1) - - logger.debug("starting socket worker process, listening for worker connection on %s", server_socket.getsockname()) - #this option set by CodeInterface + logger.debug( + "starting socket worker process, listening for worker connection on %s", + server_socket.getsockname(), + ) + + # this option set by CodeInterface logger.debug("mpi_enabled: %s", str(self.initialize_mpi)) - - # set arguments to name of the worker, and port number we listen on + + # set arguments to name of the worker, and port number we listen on self.stdout = None self.stderr = None - + if self.remote: - command,arguments=self.generate_remote_command_and_arguments(self.hostname,server_address,str(server_socket.getsockname()[1])) + command, arguments = self.generate_remote_command_and_arguments( + self.hostname, server_address, 
str(server_socket.getsockname()[1]) + ) else: - command,arguments=self.generate_command_and_arguments(server_address,str(server_socket.getsockname()[1])) - + command, arguments = self.generate_command_and_arguments( + server_address, str(server_socket.getsockname()[1]) + ) + if self.remote: - logger.debug("starting remote process on %s with command `%s`, arguments `%s` and environment '%s'", self.hostname, command, arguments, os.environ) - ssh_command=self.remote_env_string(self.hostname)+" ".join(arguments) - arguments=["ssh","-T", self.hostname] - command="ssh" - self.process = Popen(arguments, executable=command, stdin=PIPE, stdout=None, stderr=None, close_fds=self.close_fds) - self.process.stdin.write(ssh_command.encode()) - self.process.stdin.close() + logger.debug( + "starting remote process on %s with command `%s`, arguments `%s` and environment '%s'", + self.hostname, + command, + arguments, + os.environ, + ) + ssh_command = self.remote_env_string(self.hostname) + " ".join(arguments) + arguments = ["ssh", "-T", self.hostname] + command = "ssh" + self.process = Popen( + arguments, + executable=command, + stdin=PIPE, + stdout=None, + stderr=None, + close_fds=self.close_fds, + ) + self.process.stdin.write(ssh_command.encode()) + self.process.stdin.close() else: - logger.debug("starting process with command `%s`, arguments `%s` and environment '%s'", command, arguments, os.environ) - # ~ print(arguments) - self.process = Popen(arguments, executable=command, stdin=PIPE, stdout=None, stderr=None, close_fds=self.close_fds) + logger.debug( + "starting process with command `%s`, arguments `%s` and environment '%s'", + command, + arguments, + os.environ, + ) + # ~ print(arguments) + self.process = Popen( + arguments, + executable=command, + stdin=PIPE, + stdout=None, + stderr=None, + close_fds=self.close_fds, + ) logger.debug("waiting for connection from worker") - self.socket, address = self.accept_worker_connection(server_socket, self.process) - + self.socket, 
address = self.accept_worker_connection( + server_socket, self.process + ) + self.socket.setblocking(1) - + self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - + server_socket.close() - + # logger.debug("got connection from %s", address) - + # logger.info("worker %s initialized", self.name_of_the_worker) - @option(type="boolean", sections=("sockets_channel",)) def close_fds(self): @@ -1979,52 +2287,52 @@ def close_fds(self): @option(type="dict", sections=("sockets_channel",)) def remote_envs(self): - """ dict of remote machine - enviroment (source ..) pairs """ + """dict of remote machine - enviroment (source ..) pairs""" return dict() @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): """Name of the debugger to use when starting the code""" return "none" - + @option(sections=("channel",)) def hostname(self): return None - + def stop(self): - if (self.socket == None): + if self.socket == None: return - + logger.debug("stopping socket worker %s", self.name_of_the_worker) self.socket.close() - + self.socket = None if not self.process.stdin is None: self.process.stdin.close() - + # should lookinto using poll with a timeout or some other mechanism # when debugger method is on, no killing count = 0 - while(count < 5): + while count < 5: returncode = self.process.poll() if not returncode is None: break time.sleep(0.2) count += 1 - + if not self.stdout is None: self.stdout.close() - + if not self.stderr is None: self.stderr.close() def is_active(self): return self.socket is not None - + def is_inuse(self): return self._is_inuse - + def determine_length_from_datax(self, dtype_to_arguments): def get_length(type_and_values): argument_type, argument_values = type_and_values @@ -2037,98 +2345,136 @@ def get_length(type_and_values): except: result = max(result, 1) return result - - - + lengths = [get_length(x) for x in dtype_to_arguments.items()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) - - def 
send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = ()): - + + def send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=() + ): + call_count = self.determine_length_from_data(dtype_to_arguments) - + # logger.info("sending message for call id %d, function %d, length %d", id, tag, length) - + if self.is_inuse(): - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) if self.socket is None: - raise exceptions.CodeException("You've tried to send a message to a code that is not running") - - + raise exceptions.CodeException( + "You've tried to send a message to a code that is not running" + ) + if call_count > self.max_message_length: - self.split_message(call_id, function_id, call_count, dtype_to_arguments, encoded_units) + self.split_message( + call_id, function_id, call_count, dtype_to_arguments, encoded_units + ) else: - message = SocketMessage(call_id, function_id, call_count, dtype_to_arguments, encoded_units = encoded_units) + message = SocketMessage( + call_id, + function_id, + call_count, + dtype_to_arguments, + encoded_units=encoded_units, + ) message.send(self.socket) self._is_inuse = True def recv_message(self, call_id, function_id, handle_as_array, has_units=False): - + self._is_inuse = False - + if self._communicated_splitted_message: x = self._merged_results_splitted_message self._communicated_splitted_message = False del self._merged_results_splitted_message return x - + message = SocketMessage() - + message.receive(self.socket) if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] if len(message.strings) > 0 else "no error message" + ) if message.call_id != call_id or message.function_id != 
function_id: - self.stop() - error_message+=" - code probably died, sorry." + self.stop() + error_message += " - code probably died, sorry." raise exceptions.CodeException("Error in code: " + error_message) if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) - def nonblocking_recv_message(self, call_id, function_id, handle_as_array, has_units=False): + def nonblocking_recv_message( + self, call_id, function_id, handle_as_array, has_units=False + ): request = SocketMessage().nonblocking_receive(self.socket) - + def handle_result(function): self._is_inuse = False - + message = function() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] + if len(message.strings) > 0 + else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." - raise exceptions.CodeException("Error in (asynchronous) communication with worker: " + error_message) - + self.stop() + error_message += " - code probably died, sorry." 
+ raise exceptions.CodeException( + "Error in (asynchronous) communication with worker: " + + error_message + ) + if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) - + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) + if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) request.add_result_handler(handle_result) - + return request @option(type="int", sections=("channel",)) @@ -2136,241 +2482,290 @@ def max_message_length(self): """ For calls to functions that can handle arrays, MPI messages may get too long for large N. The MPI channel will split long messages into blocks of size max_message_length. 
- """ + """ return 1000000 - def sanitize_host(self,hostname): + def sanitize_host(self, hostname): if "@" in hostname: - return hostname.split("@")[1] + return hostname.split("@")[1] return hostname - + def get_host_ip(self, client): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect((self.sanitize_host(client), 80)) - ip=s.getsockname()[0] + ip = s.getsockname()[0] s.close() return ip - def makedirs(self,directory): - if self.remote: - args=["ssh","-T", self.hostname] - command=f"mkdir -p {directory}\n" - proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") - out,err=proc.communicate(command.encode()) - else: - os.makedirs(directory) + def makedirs(self, directory): + if self.remote: + args = ["ssh", "-T", self.hostname] + command = f"mkdir -p {directory}\n" + proc = Popen(args, stdout=PIPE, stdin=PIPE, executable="ssh") + out, err = proc.communicate(command.encode()) + else: + os.makedirs(directory) class OutputHandler(threading.Thread): - + def __init__(self, stream, port): threading.Thread.__init__(self) self.stream = stream logger.debug("output handler connecting to daemon at %d", port) - + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - address = ('localhost', port) - + + address = ("localhost", port) + try: self.socket.connect(address) except: - raise exceptions.CodeException("Could not connect to Distributed Daemon at " + str(address)) - + raise exceptions.CodeException( + "Could not connect to Distributed Daemon at " + str(address) + ) + self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - - self.socket.sendall('TYPE_OUTPUT'.encode('utf-8')) + + self.socket.sendall("TYPE_OUTPUT".encode("utf-8")) # fetch ID of this connection - + result = SocketMessage() result.receive(self.socket) - + self.id = result.strings[0] - + logger.debug("output handler successfully connected to daemon at %d", port) self.daemon = True self.start() - + def run(self): - + while True: # logger.debug("receiving data for output") data = 
self.socket.recv(1024) - + if len(data) == 0: # logger.debug("end of output", len(data)) return - + # logger.debug("got %d bytes", len(data)) - + self.stream.write(data) class DistributedChannel(AbstractMessageChannel): - + default_distributed_instance = None - + @staticmethod def getStdoutID(instance): if not hasattr(instance, "_stdoutHandler") or instance._stdoutHandler is None: instance._stdoutHandler = OutputHandler(sys.stdout, instance.port) - + return instance._stdoutHandler.id - + @staticmethod def getStderrID(instance): if not hasattr(instance, "_stderrHandler") or instance._stderrHandler is None: instance._stderrHandler = OutputHandler(sys.stderr, instance.port) - + return instance._stderrHandler.id - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, - distributed_instance=None, dynamic_python_code=False, **options): + + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + distributed_instance=None, + dynamic_python_code=False, + **options, + ): AbstractMessageChannel.__init__(self, **options) - + self._is_inuse = False self._communicated_splitted_message = False - + if distributed_instance is None: if self.default_distributed_instance is None: - raise Exception("No default distributed instance present, and none explicitly passed to code") + raise Exception( + "No default distributed instance present, and none explicitly passed to code" + ) self.distributed_instance = self.default_distributed_instance else: self.distributed_instance = distributed_instance - - #logger.setLevel(logging.DEBUG) - + + # logger.setLevel(logging.DEBUG) + logger.info("initializing DistributedChannel with options %s", options) - - self.socket=None - + + self.socket = None + self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + self.dynamic_python_code = dynamic_python_code - + if self.number_of_workers == 0: self.number_of_workers = 1 
- + if self.label == None: self.label = "" - - logger.debug("number of workers is %d, number of threads is %s, label is %s", self.number_of_workers, self.number_of_threads, self.label) - - self.daemon_host = 'localhost' # Distributed process always running on the local machine - self.daemon_port = self.distributed_instance.port # Port number for the Distributed process + + logger.debug( + "number of workers is %d, number of threads is %s, label is %s", + self.number_of_workers, + self.number_of_threads, + self.label, + ) + + self.daemon_host = ( + "localhost" # Distributed process always running on the local machine + ) + self.daemon_port = ( + self.distributed_instance.port + ) # Port number for the Distributed process logger.debug("port is %d", self.daemon_port) - + self.id = 0 - + if not legacy_interface_type is None: # worker specified by type. Figure out where this file is # mostly (only?) used by dynamic python codes - directory_of_this_module = os.path.dirname(inspect.getfile(legacy_interface_type)) - worker_path = os.path.join(directory_of_this_module, self.name_of_the_worker) - self.full_name_of_the_worker = os.path.normpath(os.path.abspath(worker_path)) - + directory_of_this_module = os.path.dirname( + inspect.getfile(legacy_interface_type) + ) + worker_path = os.path.join( + directory_of_this_module, self.name_of_the_worker + ) + self.full_name_of_the_worker = os.path.normpath( + os.path.abspath(worker_path) + ) + self.name_of_the_worker = os.path.basename(self.full_name_of_the_worker) - + else: # worker specified by executable (usually already absolute) - self.full_name_of_the_worker = os.path.normpath(os.path.abspath(self.name_of_the_worker)) - + self.full_name_of_the_worker = os.path.normpath( + os.path.abspath(self.name_of_the_worker) + ) + global_options = GlobalOptions() - - self.executable = os.path.relpath(self.full_name_of_the_worker, global_options.amuse_rootdirectory) - + + self.executable = os.path.relpath( + self.full_name_of_the_worker, 
global_options.amuse_rootdirectory + ) + self.worker_dir = os.path.dirname(self.full_name_of_the_worker) - + logger.debug("executable is %s", self.executable) logger.debug("full name of the worker is %s", self.full_name_of_the_worker) - + logger.debug("worker dir is %s", self.worker_dir) - + self._is_inuse = False def check_if_worker_is_up_to_date(self, object): -# if self.hostname != 'localhost': -# return -# -# logger.debug("hostname = %s, checking for worker", self.hostname) -# -# AbstractMessageChannel.check_if_worker_is_up_to_date(self, object) - + # if self.hostname != 'localhost': + # return + # + # logger.debug("hostname = %s, checking for worker", self.hostname) + # + # AbstractMessageChannel.check_if_worker_is_up_to_date(self, object) + pass - + def start(self): logger.debug("connecting to daemon") - + # if redirect = none, set output file to console stdout stream ID, otherwise make absolute - if (self.redirect_stdout_file == 'none'): + if self.redirect_stdout_file == "none": self.redirect_stdout_file = self.getStdoutID(self.distributed_instance) else: self.redirect_stdout_file = os.path.abspath(self.redirect_stdout_file) # if redirect = none, set error file to console stderr stream ID, otherwise make absolute - if (self.redirect_stderr_file == 'none'): + if self.redirect_stderr_file == "none": self.redirect_stderr_file = self.getStderrID(self.distributed_instance) else: self.redirect_stderr_file = os.path.abspath(self.redirect_stderr_file) - + logger.debug("output send to = " + self.redirect_stdout_file) - + logger.debug("error send to = " + self.redirect_stderr_file) - + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: self.socket.connect((self.daemon_host, self.daemon_port)) except: self.socket = None - raise exceptions.CodeException("Could not connect to Ibis Daemon at " + str(self.daemon_port)) - + raise exceptions.CodeException( + "Could not connect to Ibis Daemon at " + str(self.daemon_port) + ) + self.socket.setblocking(1) - + 
self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - - self.socket.sendall('TYPE_WORKER'.encode('utf-8')) - - arguments = {'string': [self.executable, self.redirect_stdout_file, self.redirect_stderr_file, self.label, self.worker_dir], 'int32': [self.number_of_workers, self.number_of_threads], 'bool': [ self.dynamic_python_code]} - - message = SocketMessage(call_id=1, function_id=10101010, call_count=1, dtype_to_arguments=arguments) + + self.socket.sendall("TYPE_WORKER".encode("utf-8")) + + arguments = { + "string": [ + self.executable, + self.redirect_stdout_file, + self.redirect_stderr_file, + self.label, + self.worker_dir, + ], + "int32": [self.number_of_workers, self.number_of_threads], + "bool": [self.dynamic_python_code], + } + + message = SocketMessage( + call_id=1, function_id=10101010, call_count=1, dtype_to_arguments=arguments + ) message.send(self.socket) - + logger.info("waiting for worker %s to be initialized", self.name_of_the_worker) result = SocketMessage() result.receive(self.socket) - + if result.error: logger.error("Could not start worker: %s", result.strings[0]) self.stop() - raise exceptions.CodeException("Could not start worker for " + self.name_of_the_worker + ": " + result.strings[0]) - + raise exceptions.CodeException( + "Could not start worker for " + + self.name_of_the_worker + + ": " + + result.strings[0] + ) + self.remote_amuse_dir = result.strings[0] - + logger.info("worker %s initialized", self.name_of_the_worker) logger.info("worker remote amuse dir = %s", self.remote_amuse_dir) - + @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): """Name of the debugger to use when starting the code""" return "none" - + def get_amuse_root_directory(self): return self.remote_amuse_dir - + @option(type="int", sections=("channel",)) def number_of_threads(self): return 0 - + @option(type="string", sections=("channel",)) def label(self): return None - + def stop(self): if self.socket is not 
None: logger.info("stopping worker %s", self.name_of_the_worker) @@ -2378,11 +2773,11 @@ def stop(self): self.socket = None def is_active(self): - return self.socket is not None - + return self.socket is not None + def is_inuse(self): return self._is_inuse - + def determine_length_from_datax(self, dtype_to_arguments): def get_length(x): if x: @@ -2391,181 +2786,218 @@ def get_length(x): return len(x[0]) except: return 1 - - - + lengths = [get_length(x) for x in dtype_to_arguments.values()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) - - def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = None): - + + def send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=None + ): + call_count = self.determine_length_from_data(dtype_to_arguments) - - logger.debug("sending message for call id %d, function %d, length %d", call_id, function_id, call_count) - + + logger.debug( + "sending message for call id %d, function %d, length %d", + call_id, + function_id, + call_count, + ) + if self.is_inuse(): - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) if self.socket is None: - raise exceptions.CodeException("You've tried to send a message to a code that is not running") - + raise exceptions.CodeException( + "You've tried to send a message to a code that is not running" + ) + if call_count > self.max_message_length: - self.split_message(call_id, function_id, call_count, dtype_to_arguments, encoded_units) + self.split_message( + call_id, function_id, call_count, dtype_to_arguments, encoded_units + ) else: - message = SocketMessage(call_id, function_id, call_count, dtype_to_arguments, False, False) + message = SocketMessage( + call_id, function_id, call_count, dtype_to_arguments, False, False 
+ ) message.send(self.socket) self._is_inuse = True - def recv_message(self, call_id, function_id, handle_as_array, has_units=False): - + self._is_inuse = False - + if self._communicated_splitted_message: x = self._merged_results_splitted_message self._communicated_splitted_message = False del self._merged_results_splitted_message return x - + message = SocketMessage() - + message.receive(self.socket) if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] if len(message.strings) > 0 else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - #~ self.stop() - error_message+=" - code probably died, sorry." + # ~ self.stop() + error_message += " - code probably died, sorry." raise exceptions.CodeException("Error in worker: " + error_message) if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) - - - def nonblocking_recv_message(self, call_id, function_id, handle_as_array, has_units=False): + def nonblocking_recv_message( + self, call_id, function_id, handle_as_array, has_units=False + ): # raise exceptions.CodeException("Nonblocking receive not supported by DistributedChannel") request = SocketMessage().nonblocking_receive(self.socket) - + def handle_result(function): self._is_inuse = False - + message = function() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] + if len(message.strings) > 0 + else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." - raise exceptions.CodeException("Error in (asynchronous) communication with worker: " + error_message) - + self.stop() + error_message += " - code probably died, sorry." 
+ raise exceptions.CodeException( + "Error in (asynchronous) communication with worker: " + + error_message + ) + if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) - + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) + if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) request.add_result_handler(handle_result) - + return request - + @option(type="int", sections=("channel",)) def max_message_length(self): """ For calls to functions that can handle arrays, MPI messages may get too long for large N. The MPI channel will split long messages into blocks of size max_message_length. 
- """ + """ return 1000000 + class LocalChannel(AbstractMessageChannel): - - - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, - distributed_instance=None, dynamic_python_code=False, **options): + + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + distributed_instance=None, + dynamic_python_code=False, + **options, + ): AbstractMessageChannel.__init__(self, **options) MpiChannel.ensure_mpi_initialized() if not legacy_interface_type is None: self.so_module = legacy_interface_type.__so_module__ - self.package, _ = legacy_interface_type.__module__.rsplit('.',1) + self.package, _ = legacy_interface_type.__module__.rsplit(".", 1) else: - raise Exception("Need to give the legacy interface type for the local channel") - + raise Exception( + "Need to give the legacy interface type for the local channel" + ) + self.legacy_interface_type = legacy_interface_type self._is_inuse = False self.module = None - - - def check_if_worker_is_up_to_date(self, object): pass - + def start(self): from . import import_module from . import python_code - + module = import_module.import_unique(self.package + "." + self.so_module) print(module, self.package + "." + self.so_module) module.set_comm_world(MPI.COMM_SELF) - self.local_implementation = python_code.CythonImplementation(module, self.legacy_interface_type) + self.local_implementation = python_code.CythonImplementation( + module, self.legacy_interface_type + ) self.module = module - - def stop(self): from . 
import import_module + import_module.cleanup_module(self.module) self.module = None - def is_active(self): return not self.module is None - + def is_inuse(self): return self._is_inuse - - - - def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = None): - + + def send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=None + ): + call_count = self.determine_length_from_data(dtype_to_arguments) - - self.message = LocalMessage(call_id, function_id, call_count, dtype_to_arguments, encoded_units = encoded_units) - self.is_inuse = True - + self.message = LocalMessage( + call_id, + function_id, + call_count, + dtype_to_arguments, + encoded_units=encoded_units, + ) + self.is_inuse = True def recv_message(self, call_id, function_id, handle_as_array, has_units=False): output_message = LocalMessage(call_id, function_id, self.message.call_count) self.local_implementation.handle_message(self.message, output_message) - + if has_units: - return output_message.to_result(handle_as_array),output_message.encoded_units + return ( + output_message.to_result(handle_as_array), + output_message.encoded_units, + ) else: return output_message.to_result(handle_as_array) - - - def nonblocking_recv_message(self, call_id, function_id, handle_as_array): pass - def determine_length_from_datax(self, dtype_to_arguments): def get_length(x): if x: @@ -2575,19 +3007,16 @@ def get_length(x): except: return 1 return 1 - - - + lengths = [get_length(x) for x in dtype_to_arguments.values()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) - - def is_polling_supported(self): return False + class LocalMessage(AbstractMessage): pass diff --git a/src/amuse/rfi/gencode.py b/src/amuse/rfi/gencode.py index 4f2f10adae..cc82d6c2f7 100755 --- a/src/amuse/rfi/gencode.py +++ b/src/amuse/rfi/gencode.py @@ -39,13 +39,13 @@ def get_amuse_directory(): return os.path.abspath(directory_of_this_script) # in case of trouble consult old python 2: - #~ def 
get_amuse_directory(): - #~ filename_of_this_script = __file__ - #~ directory_of_this_script = os.path.dirname(os.path.dirname(filename_of_this_script)) - #~ if os.path.isabs(directory_of_this_script): - #~ return directory_of_this_script - #~ else: - #~ return os.path.abspath(directory_of_this_script) + # def get_amuse_directory(): + # filename_of_this_script = __file__ + # directory_of_this_script = os.path.dirname(os.path.dirname(filename_of_this_script)) + # if os.path.isabs(directory_of_this_script): + # return directory_of_this_script + # else: + # return os.path.abspath(directory_of_this_script) def setup_sys_path(): amuse_directory = os.environ["AMUSE_DIR"] diff --git a/src/amuse/rfi/python_code.py b/src/amuse/rfi/python_code.py index 706a7ee41a..e28aee20f1 100644 --- a/src/amuse/rfi/python_code.py +++ b/src/amuse/rfi/python_code.py @@ -495,20 +495,20 @@ def must_disconnect(self): def internal__become_code(self, number_of_workers, modulename, classname): warnings.warn(" possible experimental code path?") - #~ print number_of_workers, modulename, classname + # print number_of_workers, modulename, classname world = self.freeworld color = 0 if world.rank < number_of_workers else 1 key = world.rank if world.rank < number_of_workers else world.rank - number_of_workers - #~ print "CC,", color, key, world.rank, world.size + # print "CC,", color, key, world.rank, world.size newcomm = world.Split(color, key) - #~ print ("nc:", newcomm.size, newcomm.rank) - #~ print ("AA", self.world, color, self.world.rank, self.world.size) + # print ("nc:", newcomm.size, newcomm.rank) + # print ("AA", self.world, color, self.world.rank, self.world.size) try: new_intercomm = newcomm.Create_intercomm(0, self.world, 0, color) except Exception as ex: warnings.warn(str(ex)) raise ex - #~ print ("nccc:", new_intercomm.Get_remote_size(), new_intercomm.rank) + # print ("nccc:", new_intercomm.Get_remote_size(), new_intercomm.rank) self.communicators.append(new_intercomm) self.id_to_activate 
= len(self.communicators) - 1 From 389515b20761232edc1ad19304d07dd75cc53815 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Wed, 9 Oct 2024 16:45:08 +0200 Subject: [PATCH 07/12] syntax updates to amuse.rfi --- .../community/fractalcluster/interface.py | 274 ++++--- src/amuse/community/galactics/Makefile | 2 +- .../community/galactics/src/src/genhalo.c | 2 +- src/amuse/datamodel/base.py | 148 ++-- src/amuse/datamodel/particles.py | 23 +- src/amuse/io/base.py | 95 ++- src/amuse/io/store_v2.py | 2 +- src/amuse/rfi/async_request.py | 467 ++++++------ src/amuse/rfi/channel.py | 195 ++--- src/amuse/rfi/gencode.py | 10 +- src/amuse/rfi/import_module.py | 90 ++- src/amuse/rfi/nospawn.py | 103 +-- src/amuse/rfi/python_code.py | 424 ++++++----- src/amuse/rfi/run_command_redirected.py | 47 +- src/amuse/rfi/slurm.py | 48 +- src/amuse/rfi/tools/create_c.py | 445 ++++++----- src/amuse/rfi/tools/create_code.py | 67 +- src/amuse/rfi/tools/create_definition.py | 275 +++---- src/amuse/rfi/tools/create_dir.py | 154 ++-- src/amuse/rfi/tools/create_fortran.py | 702 ++++++++++-------- src/amuse/rfi/tools/create_java.py | 437 ++++++----- src/amuse/rfi/tools/create_python_worker.py | 58 +- src/amuse/rfi/tools/fortran_tools.py | 454 ++++++----- 23 files changed, 2539 insertions(+), 1983 deletions(-) diff --git a/src/amuse/community/fractalcluster/interface.py b/src/amuse/community/fractalcluster/interface.py index 8f150d140b..12022ad116 100644 --- a/src/amuse/community/fractalcluster/interface.py +++ b/src/amuse/community/fractalcluster/interface.py @@ -1,197 +1,247 @@ import numpy import os.path -from amuse.community import * +from amuse.community import ( + CodeInterface, LiteratureReferencesMixIn, + legacy_function, LegacyFunctionSpecification, + InCodeComponentImplementation, + exceptions, +) from amuse.datamodel import Particles from amuse import datamodel from amuse.units import nbody_system -from amuse.community.interface.common import CommonCodeInterface, CommonCode +from 
amuse.community.interface.common import CommonCode -class FractalClusterInterface(CodeInterface, LiteratureReferencesMixIn): + +class FractalClusterInterface(CodeInterface, LiteratureReferencesMixIn): """ - makes fractal of nstar particles of dimension fdim, using ndiv + Makes a fractal distribution of N particles of dimension fdim, using ndiv subunits forcing the number of cells if force=.true. - - reference: + + Reference: .. [#] ADS:2004A&A...413..929G (Simon Goodwin & Ant Whitworth (2004, A&A, 413, 929)) """ def __init__(self, **options): - CodeInterface.__init__(self, name_of_the_worker = self.name_of_the_worker(), **options) + CodeInterface.__init__( + self, name_of_the_worker=self.name_of_the_worker(), **options + ) LiteratureReferencesMixIn.__init__(self) def name_of_the_worker(self): - return 'fractal_worker' + return "fractal_worker" @legacy_function def get_state(): function = LegacyFunctionSpecification() function.must_handle_array = True - function.addParameter('id', dtype='i', direction=function.IN) - for x in ['x','y','z','vx','vy','vz']: - function.addParameter(x, dtype='d', direction=function.OUT) - function.addParameter('nstar', dtype='i', direction=function.LENGTH) - function.result_type = 'i' + function.addParameter("id", dtype="i", direction=function.IN) + for x in ["x", "y", "z", "vx", "vy", "vz"]: + function.addParameter(x, dtype="d", direction=function.OUT) + function.addParameter("nstar", dtype="i", direction=function.LENGTH) + function.result_type = "i" return function @legacy_function def generate_particles(): function = LegacyFunctionSpecification() - function.result_type = 'i' + function.result_type = "i" return function @legacy_function def get_fractal_dimension(): function = LegacyFunctionSpecification() - function.addParameter('fdim', dtype='d', direction=function.OUT) - function.result_type = 'i' + function.addParameter("fdim", dtype="d", direction=function.OUT) + function.result_type = "i" return function + @legacy_function def 
set_fractal_dimension(): function = LegacyFunctionSpecification() - function.addParameter('fdim', dtype='d', direction=function.IN) - function.result_type = 'i' + function.addParameter("fdim", dtype="d", direction=function.IN) + function.result_type = "i" return function @legacy_function def get_random_seed(): function = LegacyFunctionSpecification() - function.addParameter('seed', dtype='i', direction=function.OUT) - function.result_type = 'i' + function.addParameter("seed", dtype="i", direction=function.OUT) + function.result_type = "i" return function + @legacy_function def set_random_seed(): function = LegacyFunctionSpecification() - function.addParameter('seed', dtype='i', direction=function.IN) - function.result_type = 'i' + function.addParameter("seed", dtype="i", direction=function.IN) + function.result_type = "i" return function @legacy_function def get_nstar(): function = LegacyFunctionSpecification() - function.addParameter('nstar', dtype='i', direction=function.OUT) - function.result_type = 'i' + function.addParameter("nstar", dtype="i", direction=function.OUT) + function.result_type = "i" return function + @legacy_function def set_nstar(): function = LegacyFunctionSpecification() - function.addParameter('nstar', dtype='i', direction=function.IN) - function.result_type = 'i' + function.addParameter("nstar", dtype="i", direction=function.IN) + function.result_type = "i" return function - + new_particle = None - + def delete_particle(self, index_of_the_particle): return 0 - + @legacy_function def get_number_of_particles_updated(): """ Return the number of particles added during the last generate_particles. 
""" function = LegacyFunctionSpecification() - function.addParameter('number_of_particles', dtype='int32', direction=function.OUT) - function.result_type = 'int32' + function.addParameter( + "number_of_particles", dtype="int32", direction=function.OUT + ) + function.result_type = "int32" return function - class FractalCluster(CommonCode): - + def __init__(self, unit_converter=None, **options): self.unit_converter = unit_converter - InCodeComponentImplementation.__init__(self, FractalClusterInterface(**options), **options) - + InCodeComponentImplementation.__init__( + self, FractalClusterInterface(**options), **options + ) + def initialize_code(self): pass - + def cleanup_code(self): pass - + def commit_parameters(self): pass - + def recommit_parameters(self): pass - + def define_parameters(self, handler): handler.add_method_parameter( "get_nstar", "set_nstar", "number_of_particles", "the number of particles to be generated in the model", - default_value = 0 + default_value=0, ) - + handler.add_method_parameter( "get_fractal_dimension", "set_fractal_dimension", "fractal_dimension", "the fractal dimension of the spatial particle distribution", - default_value = 1.6 + default_value=1.6, ) - + handler.add_method_parameter( "get_random_seed", "set_random_seed", "random_seed", "the initial seed to be used by the random number generator", - default_value = 1234321 + default_value=1234321, ) - + def define_methods(self, handler): CommonCode.define_methods(self, handler) handler.add_method("generate_particles", (), (handler.ERROR_CODE,)) - handler.add_method("get_number_of_particles_updated", (), (handler.NO_UNIT, handler.ERROR_CODE,)) - - handler.add_method("get_state", (handler.INDEX,), - [nbody_system.length]*3 + [nbody_system.speed]*3 + [handler.ERROR_CODE] + handler.add_method( + "get_number_of_particles_updated", + (), + ( + handler.NO_UNIT, + handler.ERROR_CODE, + ), + ) + + handler.add_method( + "get_state", + (handler.INDEX,), + [nbody_system.length] * 3 + 
[nbody_system.speed] * 3 + [handler.ERROR_CODE], + ) + + handler.add_method( + "get_target_number_of_particles", + (), + ( + handler.NO_UNIT, + handler.ERROR_CODE, + ), + ) + handler.add_method( + "set_target_number_of_particles", (handler.NO_UNIT,), (handler.ERROR_CODE,) + ) + + handler.add_method( + "get_fractal_dimension", + (), + ( + handler.NO_UNIT, + handler.ERROR_CODE, + ), + ) + handler.add_method( + "set_fractal_dimension", (handler.NO_UNIT,), (handler.ERROR_CODE,) ) - - handler.add_method("get_target_number_of_particles", (), (handler.NO_UNIT, handler.ERROR_CODE,)) - handler.add_method("set_target_number_of_particles", (handler.NO_UNIT, ), (handler.ERROR_CODE,)) - - handler.add_method("get_fractal_dimension", (), (handler.NO_UNIT, handler.ERROR_CODE,)) - handler.add_method("set_fractal_dimension", (handler.NO_UNIT, ), (handler.ERROR_CODE,)) - - handler.add_method("get_random_seed", (), (handler.NO_UNIT, handler.ERROR_CODE,)) - handler.add_method("set_random_seed", (handler.NO_UNIT, ), (handler.ERROR_CODE,)) + + handler.add_method( + "get_random_seed", + (), + ( + handler.NO_UNIT, + handler.ERROR_CODE, + ), + ) + handler.add_method("set_random_seed", (handler.NO_UNIT,), (handler.ERROR_CODE,)) def define_converter(self, handler): - if not self.unit_converter is None: + if self.unit_converter is not None: handler.set_converter(self.unit_converter.as_converter_from_si_to_generic()) - + def define_particle_sets(self, handler): - handler.define_set('particles', 'index_of_the_particle') - handler.set_new('particles', 'new_particle') - handler.set_delete('particles', 'delete_particle') - handler.add_getter('particles', 'get_state') - + handler.define_set("particles", "index_of_the_particle") + handler.set_new("particles", "new_particle") + handler.set_delete("particles", "delete_particle") + handler.add_getter("particles", "get_state") + def define_state(self, handler): CommonCode.define_state(self, handler) - 
handler.add_transition('INITIALIZED','EDIT','commit_parameters') - handler.add_transition('EDIT','CHANGE_PARAMETERS_EDIT','before_set_parameter', False) - handler.add_transition('CHANGE_PARAMETERS_EDIT','EDIT','recommit_parameters') - - handler.add_method('CHANGE_PARAMETERS_EDIT', 'before_set_parameter') - - handler.add_method('CHANGE_PARAMETERS_EDIT', 'before_get_parameter') - handler.add_method('RUN', 'before_get_parameter') - handler.add_method('EDIT', 'before_get_parameter') - - handler.add_transition('EDIT', 'RUN', 'generate_particles', False) - handler.add_transition('RUN', 'EDIT', 'clear_particle_set') - handler.add_method('EDIT', 'get_number_of_particles_updated') - handler.add_method('RUN', 'get_number_of_particles_updated') - handler.add_method('RUN', 'get_state') - + handler.add_transition("INITIALIZED", "EDIT", "commit_parameters") + handler.add_transition( + "EDIT", "CHANGE_PARAMETERS_EDIT", "before_set_parameter", False + ) + handler.add_transition("CHANGE_PARAMETERS_EDIT", "EDIT", "recommit_parameters") + + handler.add_method("CHANGE_PARAMETERS_EDIT", "before_set_parameter") + + handler.add_method("CHANGE_PARAMETERS_EDIT", "before_get_parameter") + handler.add_method("RUN", "before_get_parameter") + handler.add_method("EDIT", "before_get_parameter") + + handler.add_transition("EDIT", "RUN", "generate_particles", False) + handler.add_transition("RUN", "EDIT", "clear_particle_set") + handler.add_method("EDIT", "get_number_of_particles_updated") + handler.add_method("RUN", "get_number_of_particles_updated") + handler.add_method("RUN", "get_state") + def generate_particles(self): result = self.overridden().generate_particles() self.update_particle_set() - + def update_particle_set(self): """ update the particle set after changes in the code - + this implementation needs to move to the amuse.datamodel.incode_storage module, as it uses a lot of internal methods and info! 
@@ -199,38 +249,49 @@ def update_particle_set(self): number_of_updated_particles = self.get_number_of_particles_updated() if number_of_updated_particles: self.particles._private.attribute_storage._add_indices( - list(range(1, number_of_updated_particles+1)) + list(range(1, number_of_updated_particles + 1)) ) - + def clear_particle_set(self): if len(self.particles): self.particles.remove_particles(self.particles) - -class MakeFractalCluster(object): - - def __init__(self, N=None, convert_nbody=None, masses=None, do_scale=True, - random_seed=None, fractal_dimension=1.6, virial_ratio=0.5, verbose=False, match_N=True): +class MakeFractalCluster: + + def __init__( + self, + N=None, + convert_nbody=None, + masses=None, + do_scale=True, + random_seed=None, + fractal_dimension=1.6, + virial_ratio=0.5, + verbose=False, + match_N=True, + ): if masses is None: if N is None: - raise exceptions.AmuseException("Either keyword argument 'N' (number of particles) or " - "'masses' (vector quantity with mass of each particle) is required.") + raise exceptions.AmuseException( + "Either keyword argument 'N' (number of particles) or " + "'masses' (vector quantity with mass of each particle) is required." 
+ ) self.masses = numpy.ones(N) / N | nbody_system.mass self.N = N else: - if not N is None and len(masses) != N: + if N is not None and len(masses) != N: print("warning: provided mass array not equal to masses") self.masses = masses / masses.sum() | nbody_system.mass self.N = len(masses) - - self.convert_nbody=convert_nbody - self.do_scale=do_scale - self.random_seed=random_seed - self.fractal_dimension=fractal_dimension - self.virial_ratio=virial_ratio - self.verbose=verbose - self.match_N=match_N + + self.convert_nbody = convert_nbody + self.do_scale = do_scale + self.random_seed = random_seed + self.fractal_dimension = fractal_dimension + self.virial_ratio = virial_ratio + self.verbose = verbose + self.match_N = match_N def new_model(self): generator = FractalCluster(redirection=("none" if self.verbose else "null")) @@ -240,7 +301,7 @@ def new_model(self): generator.parameters.random_seed = self.random_seed generator.generate_particles() if self.match_N: - while len(generator.particles)>> from amuse.units import units >>> original = Particles(2) >>> original.mass = 0 | units.m - >>> print hasattr(original, "mass") + >>> print(hasattr(original, "mass")) True - >>> print len(original) + >>> print(len(original)) 2 >>> copy = original.empty_copy() - >>> print hasattr(copy, "mass") + >>> print(hasattr(copy, "mass")) False - >>> print len(copy) + >>> print(len(copy)) 2 """ @@ -771,17 +771,16 @@ def remove_particle(self, particle): def synchronize_to(self, other_particles): """ - Synchronize the particles of this set - with the contents of the provided set. + Synchronize the particles of this set with the contents of the provided + set. - After this call the `other_particles` set will have - the same particles as this set. + After this call the `other_particles` set will have the same particles + as this set. - This call will check if particles have been removed or - added it will not copy values of existing particles - over. 
+ This call will check if particles have been removed or added, it will + not copy values of existing particles over. - :parameter other_particles: particle set wich has to be updated + :parameter other_particles: particle set which has to be updated >>> particles = Particles(2) >>> particles.x = [1.0, 2.0] | units.m diff --git a/src/amuse/io/base.py b/src/amuse/io/base.py index deaa7b4757..53b7bf9cce 100644 --- a/src/amuse/io/base.py +++ b/src/amuse/io/base.py @@ -16,19 +16,34 @@ class IoException(exceptions.CoreException): class UnsupportedFormatException(IoException): """Raised when the given format is not supported by AMUSE.""" - formatstring = "You tried to load or save a file with fileformat '{0}', but this format is not in the supported formats list" + formatstring = ( + "You tried to load or save a file with fileformat '{0}', but this format " + "is not in the supported formats list" + ) class CannotSaveException(IoException): - """Raised when the given format cannot save data (only reading of data is supported for the format)""" + """ + Raised when the given format cannot save data (only reading of data is + supported for the format) + """ - formatstring = "You tried to save a file with fileformat '{0}', but this format is not supported for writing files" + formatstring = ( + "You tried to save a file with fileformat '{0}', but this format is not " + "supported for writing files" + ) class CannotLoadException(IoException): - """Raised when the given format cannot read data (only saving of data is supported for the format)""" + """ + Raised when the given format cannot read data (only saving of data is + supported for the format) + """ - formatstring = "You tried to load a file with fileformat '{0}', but this format is not supported for reading files" + formatstring = ( + "You tried to load a file with fileformat '{0}', but this format is not " + "supported for reading files" + ) class format_option(late): @@ -41,19 +56,19 @@ def get_name(self): return 
self.initializer.__name__ -def _get_processor_factory(format): - if isinstance(format, str): - if not format in registered_fileformat_processors: - raise UnsupportedFormatException(format) - processor_factory = registered_fileformat_processors[format] +def _get_processor_factory(fileformat): + if isinstance(fileformat, str): + if fileformat not in registered_fileformat_processors: + raise UnsupportedFormatException(fileformat) + processor_factory = registered_fileformat_processors[fileformat] else: - processor_factory = format + processor_factory = fileformat return processor_factory def write_set_to_file( - set, filename, format="amuse", **format_specific_keyword_arguments + particleset, filename, format="amuse", **format_specific_keyword_arguments ): """ Write a set to the given file in the given format. @@ -70,7 +85,7 @@ class and not an instance) processor_factory = _get_processor_factory(format) - processor = processor_factory(filename, set=set, format=format) + processor = processor_factory(filename, set=particleset, format=format) processor.set_options(format_specific_keyword_arguments) processor.store() @@ -100,7 +115,7 @@ class and not an instance) return processor.load() -class ReportTable(object): +class ReportTable: """ Report quantities and values to a file. @@ -234,7 +249,7 @@ def _update_documentation_strings(): method.__doc__ = new_doc -class FileFormatProcessor(object): +class FileFormatProcessor: """ Abstract base class of all fileformat processors @@ -445,56 +460,56 @@ def int_type(self): return result.newbyteorder(self.endianness) def read_fortran_block(self, file): - """Returns one block read from file. Checks if the - block is consistant. Result is an array of bytes """ - format = self.endianness + "I" - bytes = file.read(4) - if not bytes: + Returns one block read from file. Checks if the block is consistent. + Result is an array of bytes. 
+ """ + fileformat = self.endianness + "I" + bytesarray = file.read(4) + if not bytesarray: return None - length_of_block = struct.unpack(format, bytes)[0] + length_of_block = struct.unpack(fileformat, bytesarray)[0] result = file.read(length_of_block) - bytes = file.read(4) - length_of_block_after = struct.unpack(format, bytes)[0] + bytesarray = file.read(4) + length_of_block_after = struct.unpack(fileformat, bytesarray)[0] if length_of_block_after != length_of_block: raise IoException( - "Block is mangled sizes don't match before: {0}, after: {1}".format( - length_of_block, length_of_block_after - ) + f"Block is mangled sizes don't match before: {length_of_block}, " + f"after: {length_of_block_after}" ) return result def read_fortran_block_floats(self, file): - bytes = self.read_fortran_block(file) - return numpy.frombuffer(bytes, dtype=self.float_type) + bytesarray = self.read_fortran_block(file) + return numpy.frombuffer(bytesarray, dtype=self.float_type) def read_fortran_block_doubles(self, file): - bytes = self.read_fortran_block(file) - return numpy.frombuffer(bytes, dtype=self.double_type) + bytesarray = self.read_fortran_block(file) + return numpy.frombuffer(bytesarray, dtype=self.double_type) def read_fortran_block_uints(self, file): - bytes = self.read_fortran_block(file) - return numpy.frombuffer(bytes, dtype=self.uint_type) + bytesarray = self.read_fortran_block(file) + return numpy.frombuffer(bytesarray, dtype=self.uint_type) def read_fortran_block_ulongs(self, file): - bytes = self.read_fortran_block(file) - return numpy.frombuffer(bytes, dtype=self.ulong_type) + bytesarray = self.read_fortran_block(file) + return numpy.frombuffer(bytesarray, dtype=self.ulong_type) def read_fortran_block_ints(self, file): - bytes = self.read_fortran_block(file) - return numpy.frombuffer(bytes, dtype=self.int_type) + bytesarray = self.read_fortran_block(file) + return numpy.frombuffer(bytesarray, dtype=self.int_type) def read_fortran_block_float_vectors(self, file, 
size=3):
         result = self.read_fortran_block_floats(file)
         return result.reshape(len(result) // size, size)
 
     def write_fortran_block(self, file, input):
-        format = self.endianness + "I"
-        input_bytes = bytearray(input)
+        fileformat = self.endianness + "I"
+        input_bytes = bytearray(input)
         length_of_block = len(input_bytes)
-        file.write(struct.pack(format, length_of_block))
+        file.write(struct.pack(fileformat, length_of_block))
         file.write(input_bytes)
-        file.write(struct.pack(format, length_of_block))
+        file.write(struct.pack(fileformat, length_of_block))
 
     def write_fortran_block_floats(self, file, values):
         array = numpy.asarray(values, dtype=self.float_type)
diff --git a/src/amuse/io/store_v2.py b/src/amuse/io/store_v2.py
index f6e15c4329..db23ff4c7d 100644
--- a/src/amuse/io/store_v2.py
+++ b/src/amuse/io/store_v2.py
@@ -1216,7 +1216,7 @@ def get_set_from_reference(self, reference):
         referenced_group = self.derefence(reference)
 
         mapping_from_groupid_to_set = self.mapping_from_groupid_to_set
-        if not referenced_group.id in mapping_from_groupid_to_set:
+        if referenced_group.id not in mapping_from_groupid_to_set:
             linked_set = self.load_from_group(referenced_group)
         else:
             linked_set = mapping_from_groupid_to_set[referenced_group.id]
diff --git a/src/amuse/rfi/async_request.py b/src/amuse/rfi/async_request.py
index 707a7447fe..16d0711550 100644
--- a/src/amuse/rfi/async_request.py
+++ b/src/amuse/rfi/async_request.py
@@ -3,6 +3,7 @@
 
 from .
import channel + class AbstractASyncRequest(object): def __bool__(self): return not self.is_finished @@ -40,27 +41,27 @@ def result(self): def results(self): return [self.result()] - def add_result_handler(self, function, args = ()): - self.result_handlers.append([function,args]) + def add_result_handler(self, function, args=()): + self.result_handlers.append([function, args]) def is_mpi_request(self): return False def is_socket_request(self): return False - + def is_other(self): return not self.is_mpi_request() and not self.is_socket_request() - + def get_mpi_request(self): raise Exception("not implemented") def get_socket(self): raise Exception("not implemented") - - #~ def is_pool(self): - #~ return False - + + # def is_pool(self): + # return False + def join(self, other): if other is None: return self @@ -80,74 +81,94 @@ def waits_for(self): return self def __getitem__(self, index): - return IndexedASyncRequest(self,index) + return IndexedASyncRequest(self, index) + + # def __getattr__(self, name): + # print name, "<<" - #~ def __getattr__(self, name): - #~ print name, "<<" - def __add__(self, other): - return baseOperatorASyncRequest(self,other, operator.add) + return baseOperatorASyncRequest(self, other, operator.add) + def __radd__(self, other): - return baseOperatorASyncRequest(self,other, lambda x,y: operator.add(y,x)) + return baseOperatorASyncRequest(self, other, lambda x, y: operator.add(y, x)) + def __sub__(self, other): - return baseOperatorASyncRequest(self,other, operator.sub) + return baseOperatorASyncRequest(self, other, operator.sub) + def __rsub__(self, other): - return baseOperatorASyncRequest(self,other, lambda x,y: operator.sub(y,x)) + return baseOperatorASyncRequest(self, other, lambda x, y: operator.sub(y, x)) + def __mul__(self, other): - return baseOperatorASyncRequest(self,other, operator.__mul__) + return baseOperatorASyncRequest(self, other, operator.__mul__) + def __rmul__(self, other): - return baseOperatorASyncRequest(self,other, 
lambda x,y: operator.mul(y,x)) + return baseOperatorASyncRequest(self, other, lambda x, y: operator.mul(y, x)) + def __truediv__(self, other): - return baseOperatorASyncRequest(self,other, operator.truediv) + return baseOperatorASyncRequest(self, other, operator.truediv) + def __rtruediv__(self, other): - return baseOperatorASyncRequest(self,other, lambda x,y: operator.truediv(y,x)) + return baseOperatorASyncRequest( + self, other, lambda x, y: operator.truediv(y, x) + ) + def __floordiv__(self, other): - return baseOperatorASyncRequest(self,other, operator.floordiv) + return baseOperatorASyncRequest(self, other, operator.floordiv) + def __rfloordiv__(self, other): - return baseOperatorASyncRequest(self,other, lambda x,y: operator.floordiv(y,x)) + return baseOperatorASyncRequest( + self, other, lambda x, y: operator.floordiv(y, x) + ) + def __div__(self, other): - return baseOperatorASyncRequest(self,other, operator.div) + return baseOperatorASyncRequest(self, other, operator.div) + def __rdiv__(self, other): - return baseOperatorASyncRequest(self,other, lambda x,y: operator.div(y,x)) + return baseOperatorASyncRequest(self, other, lambda x, y: operator.div(y, x)) + def __pow__(self, other): - return baseOperatorASyncRequest(self,other, operator.pow) + return baseOperatorASyncRequest(self, other, operator.pow) + def __rpow__(self, other): - return baseOperatorASyncRequest(self,other, lambda x,y: operator.pow(y,x)) + return baseOperatorASyncRequest(self, other, lambda x, y: operator.pow(y, x)) + def __mod__(self, other): - return baseOperatorASyncRequest(self,other, operator.mod) + return baseOperatorASyncRequest(self, other, operator.mod) + def __rmod__(self, other): - return baseOperatorASyncRequest(self,other, lambda x,y: operator.mod(y,x)) + return baseOperatorASyncRequest(self, other, lambda x, y: operator.mod(y, x)) + def __neg__(self): return baseOperatorASyncRequest(self, None, operator.neg) - + def __iter__(self): if self._result_index: for i in 
self._result_index: yield self[i] else: yield self - - #~ def __call__(self): - #~ return self.result() - + + # def __call__(self): + # return self.result() + + class DependentASyncRequest(AbstractASyncRequest): def __init__(self, parent, request_factory): - - self._result_index=None - self.request=None - self.parent=parent + self._result_index = None + self.request = None + self.parent = parent if isinstance(parent, AsyncRequestsPool): - self.parent=PoolDependentASyncRequest(parent) - + self.parent = PoolDependentASyncRequest(parent) + def handler(arg): - result=arg() - self.request=request_factory() + result = arg() + self.request = request_factory() for h in self.result_handlers: self.request.add_result_handler(*h) return result self.parent.add_result_handler(handler) - + self.result_handlers = [] @property @@ -160,19 +181,19 @@ def is_result_set(self): def is_finished(self): if self.request is None: if self.parent.is_finished: - return True + return True else: return False - + return self.request.is_finished - def wait(self): + def wait(self): try: self.parent.waitall() except Exception as ex: - message=str(ex) + message = str(ex) if not message.startswith("Error in dependent call: "): - message="Error in dependent call: "+str(ex) + message = "Error in dependent call: " + str(ex) raise type(ex)(message) if self.request is None: raise Exception("something went wrong (exception of parent?)") @@ -183,13 +204,13 @@ def is_result_available(self): if self.request is None: return False - #~ if not self.parent.is_finished: - #~ return False + # if not self.parent.is_finished: + # return False if self.request is None: return False - #~ raise Exception("something went wrong (exception of parent?)") - + # raise Exception("something went wrong (exception of parent?)") + return self.request.is_result_available() def result(self): @@ -202,20 +223,20 @@ def result(self): @property def results(self): - return self.parent.results+[self.result()] + return self.parent.results + 
[self.result()] - def add_result_handler(self, function, args = ()): + def add_result_handler(self, function, args=()): if self.request is None: - self.result_handlers.append([function,args]) + self.result_handlers.append([function, args]) else: - self.request.add_result_handler(function,args) + self.request.add_result_handler(function, args) def is_mpi_request(self): if self.request is None: return self.parent.is_mpi_request() else: return self.request.is_mpi_request() - + def is_socket_request(self): if self.request is None: return self.parent.is_socket_request() @@ -230,50 +251,57 @@ def waits_for(self): else: return self.parent.waits_for() + class PoolDependentASyncRequest(DependentASyncRequest): def __init__(self, parent): - self.parent=parent - self.request=FakeASyncRequest() + self.parent = parent + self.request = FakeASyncRequest() self.result_handlers = [] + class IndexedASyncRequest(DependentASyncRequest): def __init__(self, parent, index): - self.parent=parent - self.index=index - self.request=FakeASyncRequest() + self.parent = parent + self.index = index + self.request = FakeASyncRequest() self.result_handlers = [] try: - self._result_index=parent._result_index[index] + self._result_index = parent._result_index[index] except: - self._result_index=None + self._result_index = None def result(self): self.wait() return self.parent.result().__getitem__(self.index) + class baseOperatorASyncRequest(DependentASyncRequest): def __init__(self, first, second, operator): - self._first=first - self._second=second - self._operator=operator - if isinstance( second, AbstractASyncRequest): - pool=AsyncRequestsPool(first,second) - self.parent=PoolDependentASyncRequest(pool) + self._first = first + self._second = second + self._operator = operator + if isinstance(second, AbstractASyncRequest): + pool = AsyncRequestsPool(first, second) + self.parent = PoolDependentASyncRequest(pool) else: - self.parent=first - self.request=FakeASyncRequest() + self.parent = first + 
self.request = FakeASyncRequest() self.result_handlers = [] - + def result(self): self.wait() - first=self._first.result() - second=self._second.result() if isinstance( self._second, AbstractASyncRequest) else self._second - if second is None: + first = self._first.result() + second = ( + self._second.result() + if isinstance(self._second, AbstractASyncRequest) + else self._second + ) + if second is None: return self._operator(first) - return self._operator(first,second) + return self._operator(first, second) + class ASyncRequest(AbstractASyncRequest): - def __init__(self, request, message, comm, header): self.request = request self.message = message @@ -284,17 +312,17 @@ def __init__(self, request, message, comm, header): self._called_set_result = False self._result = None self.result_handlers = [] - self._result_index=None + self._result_index = None def wait(self): if self.is_finished: return self._is_finished = True - + self.request.Wait() self._set_result() - + def is_result_available(self): if self.is_finished: return self._is_result_set @@ -302,31 +330,31 @@ def is_result_available(self): def get_message(self): return self.message - + def _set_result(self): if self._called_set_result: return - self._called_set_result=True - + self._called_set_result = True + class CallingChain(object): - def __init__(self, outer, args, inner): + def __init__(self, outer, args, inner): self.outer = outer self.inner = inner self.args = args - + def __call__(self): return self.outer(self.inner, *self.args) - + self.message.receive_content(self.comm, self.header) - + current = self.get_message for x, args in self.result_handlers: current = CallingChain(x, args, current) - + self._result = current() - + self._is_result_set = True - + def result(self): self.wait() @@ -340,18 +368,18 @@ def is_mpi_request(self): return False return True + class ASyncSocketRequest(AbstractASyncRequest): - def __init__(self, message, socket): self.message = message self.socket = socket - + 
self._is_finished = False self._is_result_set = False self._called_set_result = False self._result = None self.result_handlers = [] - self._result_index=None + self._result_index = None def wait(self): if self.is_finished: @@ -362,44 +390,44 @@ def wait(self): while True: readables, _r, _x = select.select([self.socket], [], []) if len(readables) == 1: - break + break self._set_result() def is_result_available(self): if self.is_finished: return self._is_result_set - + readables, _r, _x = select.select([self.socket], [], [], 0.001) - + return len(readables) == 1 - + def get_message(self): return self.message - + def _set_result(self): if self._called_set_result: return - self._called_set_result=True + self._called_set_result = True class CallingChain(object): def __init__(self, outer, args, inner): self.outer = outer self.inner = inner - self.args=args - + self.args = args + def __call__(self): return self.outer(self.inner, *self.args) - + self.message.receive(self.socket) - + current = self.get_message - for x,args in self.result_handlers: + for x, args in self.result_handlers: current = CallingChain(x, args, current) - + self._result = current() - + self._is_result_set = True - + def result(self): self.wait() @@ -413,8 +441,8 @@ def is_socket_request(self): return False return True + class FakeASyncRequest(AbstractASyncRequest): - def __init__(self, result=None): self._is_finished = False self._is_result_set = False @@ -422,40 +450,40 @@ def __init__(self, result=None): self._result = None self.__result = result self.result_handlers = [] - self._result_index=None + self._result_index = None def wait(self): if self.is_finished: return - self._is_finished = True + self._is_finished = True self._set_result() - + def is_result_available(self): return True - + def _set_result(self): if self._called_set_result: return - self._called_set_result=True + self._called_set_result = True class CallingChain(object): - def __init__(self, outer, args, inner): + def __init__(self, 
outer, args, inner): self.outer = outer self.inner = inner self.args = args - + def __call__(self): return self.outer(self.inner, *self.args) - - current = lambda : self.__result + + current = lambda: self.__result for x, args in self.result_handlers: current = CallingChain(x, args, current) - + self._result = current() - + self._is_result_set = True - + def result(self): self.wait() @@ -464,9 +492,9 @@ def result(self): return self._result + class ASyncRequestSequence(AbstractASyncRequest): - - def __init__(self, create_next_request, args = ()): + def __init__(self, create_next_request, args=()): self.create_next_request = create_next_request self.args = args self.index = 0 @@ -477,7 +505,7 @@ def __init__(self, create_next_request, args = ()): self._result = None self.result_handlers = [] self._results = [] - self._result_index=None + self._result_index = None @property def is_finished(self): @@ -487,12 +515,12 @@ def is_finished(self): def wait(self): if self.is_finished: return - - self._is_finished=True + + self._is_finished = True while self.current_async_request is not None: self.current_async_request.wait() - + self._next_request() self._set_result() @@ -502,67 +530,69 @@ def waitone(self): return self.current_async_request.wait() - + self._next_request() - + if self.current_async_request is None: - self._is_finished=True + self._is_finished = True self._set_result() - def _next_request(self): - if self.current_async_request is not None and \ - self.current_async_request.is_result_available(): + if ( + self.current_async_request is not None + and self.current_async_request.is_result_available() + ): self._results.append(self.current_async_request.result()) self.index += 1 - self.current_async_request = self.create_next_request(self.index, *self.args) + self.current_async_request = self.create_next_request( + self.index, *self.args + ) if self.current_async_request is None: self._set_result() - @property def results(self): return self._results - + def 
is_result_available(self): if self.is_finished: return True - + self._next_request() - + return self.current_async_request is None - - def add_result_handler(self, function, args = ()): - self.result_handlers.append([function,args]) - + + def add_result_handler(self, function, args=()): + self.result_handlers.append([function, args]) + def get_message(self): return self._results - + def _set_result(self): if self._called_set_result: return - self._called_set_result=True + self._called_set_result = True class CallingChain(object): - def __init__(self, outer, args, inner): + def __init__(self, outer, args, inner): self.outer = outer self.inner = inner self.args = args - + def __call__(self): return self.outer(self.inner, *self.args) - + current = self.get_message for x, args in self.result_handlers: current = CallingChain(x, args, current) - + self._result = current() - + self._is_result_set = True - + def result(self): self.wait() - + if not self._is_result_set: raise Exception("result unexpectedly not available") @@ -577,50 +607,44 @@ def is_socket_request(self): def waits_for(self): return self.current_async_request + class AsyncRequestWithHandler(object): - def __init__(self, pool, async_request, result_handler, args=(), kwargs={}): self.async_request = async_request if result_handler is None: + def empty(request): return request.result() + result_handler = empty self.result_handler = result_handler self.args = args self.kwargs = kwargs self.pool = pool - def run(self): self.result_handler(self.async_request, *self.args, **self.kwargs) - + + class AsyncRequestsPool(object): - def __init__(self, *requests): self.requests_and_handlers = [] self.registered_requests = set([]) for x in requests: self.add_request(x) - - def add_request(self, async_request, result_handler = None, args=(), kwargs={}): + + def add_request(self, async_request, result_handler=None, args=(), kwargs={}): if async_request is None: return if async_request in self.registered_requests: return - 
#~ raise Exception("Request is already registered, cannot register a request more than once") - + # raise Exception("Request is already registered, cannot register a request more than once") + self.registered_requests.add(async_request) - + self.requests_and_handlers.append( - AsyncRequestWithHandler( - self, - async_request, - result_handler, - args, - kwargs - ) + AsyncRequestWithHandler(self, async_request, result_handler, args, kwargs) ) - def waitall(self): while len(self) > 0: @@ -628,14 +652,21 @@ def waitall(self): def waitone(self): return self.wait() - + def wait(self): - # TODO need to cleanup this code # while len(self.requests_and_handlers) > 0: - requests = [x.async_request.waits_for() for x in self.requests_and_handlers if x.async_request.is_other()] - indices = [i for i, x in enumerate(self.requests_and_handlers) if x.async_request.is_other()] + requests = [ + x.async_request.waits_for() + for x in self.requests_and_handlers + if x.async_request.is_other() + ] + indices = [ + i + for i, x in enumerate(self.requests_and_handlers) + if x.async_request.is_other() + ] if len(requests) > 0: for index, x in zip(indices, requests): if x is not None: @@ -643,46 +674,64 @@ def wait(self): request_and_handler = self.requests_and_handlers[index] if request_and_handler.async_request.is_result_available(): - self.registered_requests.remove(request_and_handler.async_request) - + self.registered_requests.remove( + request_and_handler.async_request + ) + self.requests_and_handlers.pop(index) - + request_and_handler.run() break - requests_ = [x.async_request.waits_for().request for x in self.requests_and_handlers if x.async_request.is_mpi_request()] - indices_ = [i for i, x in enumerate(self.requests_and_handlers) if x.async_request.is_mpi_request()] - - requests=[] - indices=[] - for r,i in zip(requests_, indices_): + requests_ = [ + x.async_request.waits_for().request + for x in self.requests_and_handlers + if x.async_request.is_mpi_request() + ] + indices_ = [ + 
i + for i, x in enumerate(self.requests_and_handlers) + if x.async_request.is_mpi_request() + ] + + requests = [] + indices = [] + for r, i in zip(requests_, indices_): if r not in requests: requests.append(r) indices.append(i) - + if len(requests) > 0: index = channel.MPI.Request.Waitany(requests) - + index = indices[index] - + request_and_handler = self.requests_and_handlers[index] - + request_and_handler.async_request.waits_for().waitone() # will set the finished flag - + if request_and_handler.async_request.is_result_available(): self.registered_requests.remove(request_and_handler.async_request) - + self.requests_and_handlers.pop(index) - + request_and_handler.run() break - - sockets_ = [x.async_request.waits_for().socket for x in self.requests_and_handlers if x.async_request.is_socket_request()] - indices_ = [i for i, x in enumerate(self.requests_and_handlers) if x.async_request.is_socket_request()] - sockets=[] - indices=[] - for r,i in zip(sockets_, indices_): + sockets_ = [ + x.async_request.waits_for().socket + for x in self.requests_and_handlers + if x.async_request.is_socket_request() + ] + indices_ = [ + i + for i, x in enumerate(self.requests_and_handlers) + if x.async_request.is_socket_request() + ] + + sockets = [] + indices = [] + for r, i in zip(sockets_, indices_): if r not in sockets: sockets.append(r) indices.append(i) @@ -691,27 +740,26 @@ def wait(self): readable, _, _ = select.select(sockets, [], []) indices_to_delete = [] for read_socket in readable: - index = sockets.index(read_socket) - + index = indices[index] - + request_and_handler = self.requests_and_handlers[index] request_and_handler.async_request.waits_for().waitone() # will set the finished flag if request_and_handler.async_request.is_result_available(): - - self.registered_requests.remove(request_and_handler.async_request) - + self.registered_requests.remove( + request_and_handler.async_request + ) + indices_to_delete.append(index) - + request_and_handler.run() - + for x in 
reversed(list(sorted(indices_to_delete))): - self.requests_and_handlers.pop(x) - + if len(indices_to_delete) > 0: break @@ -723,20 +771,17 @@ def join(self, other): elif isinstance(other, AsyncRequestsPool): for x in other.requests_and_handlers: self.add_request( - x.async_request, - x.result_handler, - args = x.args, - kwargs = x.kwargs - ) + x.async_request, x.result_handler, args=x.args, kwargs=x.kwargs + ) else: raise Exception("can only join request or pool") return self - + def __len__(self): return len(self.requests_and_handlers) - + def __bool__(self): - return len(self)==0 + return len(self) == 0 def waits_for(self): raise Exception("pool has no waits for, should never be called") diff --git a/src/amuse/rfi/channel.py b/src/amuse/rfi/channel.py index d47c092e98..bac733ea38 100644 --- a/src/amuse/rfi/channel.py +++ b/src/amuse/rfi/channel.py @@ -44,9 +44,11 @@ from . import async_request class AbstractMessage(object): - - def __init__(self, - call_id=0, function_id=-1, call_count=1, + def __init__( + self, + call_id=0, + function_id=-1, + call_count=1, dtype_to_arguments={}, error=False, big_endian=(sys.byteorder.lower() == 'big'), @@ -115,7 +117,6 @@ def set_error(self, message): class MPIMessage(AbstractMessage): - def receive(self, comm): header = self.receive_header(comm) self.receive_content(comm, header) @@ -321,7 +322,6 @@ def mpi_send(self, comm, array): class ServerSideMPIMessage(MPIMessage): - def mpi_receive(self, comm, array): request = comm.Irecv(array, source=0, tag=999) request.Wait() @@ -356,7 +356,6 @@ def receive_header(self, comm): class ClientSideMPIMessage(MPIMessage): - def mpi_receive(self, comm, array): comm.Bcast(array, root=0) @@ -504,24 +503,29 @@ def XTERM(cls, full_name_of_the_worker, channel, interpreter_executable=None, im arguments.append(interpreter_executable) arguments.append(full_name_of_the_worker) - - command = 'xterm' + + command = "xterm" return command, arguments - @classmethod - def REDIRECT(cls, 
full_name_of_the_worker, stdoutname, stderrname, command=None, - interpreter_executable=None, run_command_redirected_file=None ): - - fname = run_command_redirected_file or run_command_redirected.__file__ - arguments = [fname , stdoutname, stderrname] - + def REDIRECT( + cls, + full_name_of_the_worker, + stdoutname, + stderrname, + command=None, + interpreter_executable=None, + run_command_redirected_file=None, + ): + fname = run_command_redirected_file or run_command_redirected.__file__ + arguments = [fname, stdoutname, stderrname] + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - - if command is None : + + if command is None: command = sys.executable return command, arguments @@ -659,7 +663,6 @@ def check_if_worker_is_up_to_date(self, object): """.format(type(object).__name__)) def get_full_name_of_the_worker(self, type): - if os.path.isabs(self.name_of_the_worker): full_name_of_the_worker=self.name_of_the_worker @@ -766,9 +769,10 @@ def get_length(type_and_values): return max(1, max(lengths)) - def split_message(self, call_id, function_id, call_count, dtype_to_arguments, encoded_units = ()): - - if call_count<=1: + def split_message( + self, call_id, function_id, call_count, dtype_to_arguments, encoded_units=() + ): + if call_count <= 1: raise Exception("split message called with call_count<=1") dtype_to_result = {} @@ -1067,7 +1071,6 @@ def is_root(cls): return MPI.COMM_WORLD.rank == 0 def start(self): - logger.debug("starting mpi worker process") logger.debug("mpi_enabled: %s", str(self.initialize_mpi)) @@ -1076,8 +1079,10 @@ def start(self): command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable, immediate_run=self.debugger_immediate_run) else: - if not self.can_redirect_output or (self.redirect_stdout_file == 'none' and self.redirect_stderr_file == 'none'): - + if not self.can_redirect_output or ( + 
self.redirect_stdout_file == "none" + and self.redirect_stderr_file == "none" + ): if self.interpreter_executable is None: command = self.full_name_of_the_worker arguments = None @@ -1127,9 +1132,9 @@ def get_length(x): return max(1, max(lengths)) - def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = ()): - - + def send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=() + ): if self.intercomm is None: raise exceptions.CodeException("You've tried to send a message to a code that is not running") @@ -1155,9 +1160,7 @@ def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_unit ) message.send(self.intercomm) - - def recv_message(self, call_id, function_id, handle_as_array, has_units = False): - + def recv_message(self, call_id, function_id, handle_as_array, has_units=False): if self._communicated_splitted_message: x = self._merged_results_splitted_message self._communicated_splitted_message = False @@ -1289,7 +1292,10 @@ def get_info_from_slurm(cls, number_of_workers): host = ','.join(hostnames) print("HOST:", host, cls._scheduler_index, os.environ['SLURM_TASKS_PER_NODE']) info = MPI.Info.Create() - info['host'] = host # actually in mpich and openmpi, the host parameter is interpreted as a comma separated list of host names, + + # actually in mpich and openmpi, the host parameter is interpreted as a + # comma separated list of host names, + info["host"] = host return info @@ -1483,9 +1489,7 @@ def check_mpi(self): class SocketMessage(AbstractMessage): - def _receive_all(self, nbytes, thesocket): - # logger.debug("receiving %d bytes", nbytes) result = [] @@ -1507,7 +1511,6 @@ def _receive_all(self, nbytes, thesocket): return b"" def receive(self, socket): - # logger.debug("receiving message") header_bytes = self._receive_all(44, socket) @@ -1642,22 +1645,26 @@ def nonblocking_receive(self, socket): def send(self, socket): - - flags = numpy.array([self.big_endian, self.error, 
len(self.encoded_units) > 0, False], dtype="b") + flags = numpy.array( + [self.big_endian, self.error, len(self.encoded_units) > 0, False], dtype="b" + ) + + header = numpy.array( + [ + self.call_id, + self.function_id, + self.call_count, + len(self.ints), + len(self.longs), + len(self.floats), + len(self.doubles), + len(self.booleans), + len(self.strings), + len(self.encoded_units), + ], + dtype="i", + ) - header = numpy.array([ - self.call_id, - self.function_id, - self.call_count, - len(self.ints), - len(self.longs), - len(self.floats), - len(self.doubles), - len(self.booleans), - len(self.strings), - len(self.encoded_units), - ], dtype='i') - # logger.debug("sending message with flags %s and header %s", flags, header) socket.sendall(flags.tobytes()) @@ -1691,12 +1698,15 @@ def send_floats(self, socket, array): def send_strings(self, socket, array): if len(array) > 0: - - lengths = numpy.array( [len(s) for s in array] ,dtype='int32') - chars=(chr(0).join(array)+chr(0)).encode("utf-8") - - if len(chars) != lengths.sum()+len(lengths): - raise Exception("send_strings size mismatch {0} vs {1}".format( len(chars) , lengths.sum()+len(lengths) )) + lengths = numpy.array([len(s) for s in array], dtype="int32") + chars = (chr(0).join(array) + chr(0)).encode("utf-8") + + if len(chars) != lengths.sum() + len(lengths): + raise Exception( + "send_strings size mismatch {0} vs {1}".format( + len(chars), lengths.sum() + len(lengths) + ) + ) self.send_ints(socket, lengths) socket.sendall(chars) @@ -1713,9 +1723,14 @@ def send_longs(self, socket, array): class SocketChannel(AbstractMessageChannel): - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, - remote_env=None, **options): + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + remote_env=None, + **options, + ): AbstractMessageChannel.__init__(self, **options) #logging.getLogger().setLevel(logging.DEBUG) @@ -1790,8 +1805,10 
@@ def generate_command_and_arguments(self,server_address,port): if not self.debugger_method is None: command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) else: - if self.redirect_stdout_file == 'none' and self.redirect_stderr_file == 'none': - + if ( + self.redirect_stdout_file == "none" + and self.redirect_stderr_file == "none" + ): if self.interpreter_executable is None: command = self.full_name_of_the_worker arguments = [] @@ -1837,10 +1854,9 @@ def remote_env_string(self, hostname): else: return "" else: - return "source "+self.remote_env +"\n" + return "source " + self.remote_env + "\n" - def generate_remote_command_and_arguments(self,hostname, server_address,port): - + def generate_remote_command_and_arguments(self, hostname, server_address, port): # get remote config args=["ssh","-T", hostname] @@ -1877,8 +1893,10 @@ def generate_remote_command_and_arguments(self,hostname, server_address,port): raise Exception("remote socket channel debugging not yet supported") #command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) else: - if self.redirect_stdout_file == 'none' and self.redirect_stderr_file == 'none': - + if ( + self.redirect_stdout_file == "none" + and self.redirect_stderr_file == "none" + ): if interpreter_executable is None: command = full_name_of_the_worker arguments = [] @@ -1921,7 +1939,6 @@ def generate_remote_command_and_arguments(self,hostname, server_address,port): return command,arguments def start(self): - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_address=self.get_host_ip(self.hostname) @@ -2045,9 +2062,10 @@ def get_length(type_and_values): return 1 return max(1, max(lengths)) - - def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = ()): - + + def send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=() + 
): call_count = self.determine_length_from_data(dtype_to_arguments) # logger.info("sending message for call id %d, function %d, length %d", id, tag, length) @@ -2067,7 +2085,6 @@ def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_unit self._is_inuse = True def recv_message(self, call_id, function_id, handle_as_array, has_units=False): - self._is_inuse = False if self._communicated_splitted_message: @@ -2162,7 +2179,6 @@ def makedirs(self,directory): class OutputHandler(threading.Thread): - def __init__(self, stream, port): threading.Thread.__init__(self) self.stream = stream @@ -2194,7 +2210,6 @@ def __init__(self, stream, port): self.start() def run(self): - while True: # logger.debug("receiving data for output") data = self.socket.recv(1024) @@ -2209,7 +2224,6 @@ def run(self): class DistributedChannel(AbstractMessageChannel): - default_distributed_instance = None @staticmethod @@ -2399,9 +2413,10 @@ def get_length(x): return 1 return max(1, max(lengths)) - - def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = None): - + + def send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=None + ): call_count = self.determine_length_from_data(dtype_to_arguments) logger.debug("sending message for call id %d, function %d, length %d", call_id, function_id, call_count) @@ -2421,7 +2436,6 @@ def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_unit def recv_message(self, call_id, function_id, handle_as_array, has_units=False): - self._is_inuse = False if self._communicated_splitted_message: @@ -2437,7 +2451,7 @@ def recv_message(self, call_id, function_id, handle_as_array, has_units=False): if message.error: error_message=message.strings[0] if len(message.strings)>0 else "no error message" if message.call_id != call_id or message.function_id != function_id: - #~ self.stop() + # self.stop() error_message+=" - code probably died, sorry." 
raise exceptions.CodeException("Error in worker: " + error_message) @@ -2490,11 +2504,15 @@ def max_message_length(self): return 1000000 class LocalChannel(AbstractMessageChannel): - - - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, - distributed_instance=None, dynamic_python_code=False, **options): + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + distributed_instance=None, + dynamic_python_code=False, + **options, + ): AbstractMessageChannel.__init__(self, **options) MpiChannel.ensure_mpi_initialized() @@ -2537,11 +2555,10 @@ def is_active(self): def is_inuse(self): return self._is_inuse - - - - def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_units = None): - + + def send_message( + self, call_id, function_id, dtype_to_arguments={}, encoded_units=None + ): call_count = self.determine_length_from_data(dtype_to_arguments) self.message = LocalMessage(call_id, function_id, call_count, dtype_to_arguments, encoded_units = encoded_units) diff --git a/src/amuse/rfi/gencode.py b/src/amuse/rfi/gencode.py index 4f2f10adae..5ba79a900e 100755 --- a/src/amuse/rfi/gencode.py +++ b/src/amuse/rfi/gencode.py @@ -38,14 +38,6 @@ def get_amuse_directory(): else: return os.path.abspath(directory_of_this_script) -# in case of trouble consult old python 2: - #~ def get_amuse_directory(): - #~ filename_of_this_script = __file__ - #~ directory_of_this_script = os.path.dirname(os.path.dirname(filename_of_this_script)) - #~ if os.path.isabs(directory_of_this_script): - #~ return directory_of_this_script - #~ else: - #~ return os.path.abspath(directory_of_this_script) def setup_sys_path(): amuse_directory = os.environ["AMUSE_DIR"] @@ -292,7 +284,7 @@ def module_name(string): ) ) - string = string[len(amuse_src_directory) + 1:] + string = string[len(amuse_src_directory) + 1 :] string = string[: -len(".py")] string = string.replace(os.sep, ".") return 
string diff --git a/src/amuse/rfi/import_module.py b/src/amuse/rfi/import_module.py index e3f5a40fa1..f664e388e8 100644 --- a/src/amuse/rfi/import_module.py +++ b/src/amuse/rfi/import_module.py @@ -11,6 +11,7 @@ class _ModuleRegister(object): is_cleanup_registered = False files_to_cleanup = [] + def find_shared_object_file(dirpath, base_libname): for path in sys.path: fullpath = os.path.join(path, dirpath) @@ -19,24 +20,28 @@ def find_shared_object_file(dirpath, base_libname): if os.path.exists(full_libname): return full_libname return base_libname - + + def find_module(modulename): - parts = modulename.split('.') + parts = modulename.split(".") modulename = parts[-1] dirparts = parts[:-1] - base_libname = modulename + '.so' + base_libname = modulename + ".so" if len(dirparts) > 0: dirpath = os.path.join(*dirparts) else: - dirpath = '' + dirpath = "" libname = find_shared_object_file(dirpath, base_libname) if not os.path.exists(libname): - raise Exception("cannot find the shared object file of the module '{0}'".format(modulename)) + raise Exception( + "cannot find the shared object file of the module '{0}'".format(modulename) + ) return modulename, libname - + + def import_unique(modulename): modulename, libname = find_module(modulename) - + if modulename in sys.modules: prevmodule = sys.modules[modulename] else: @@ -44,19 +49,21 @@ def import_unique(modulename): if not _ModuleRegister.is_cleanup_registered: _ModuleRegister.is_cleanup_registered = True atexit.register(cleanup) - - if not os.path.exists('__modules__'): - os.mkdir('__modules__') + + if not os.path.exists("__modules__"): + os.mkdir("__modules__") try: - with tempfile.NamedTemporaryFile(suffix=".so", dir='__modules__', delete=False) as target: + with tempfile.NamedTemporaryFile( + suffix=".so", dir="__modules__", delete=False + ) as target: with open(libname, "rb") as source: shutil.copyfileobj(source, target) target.flush() - + _ModuleRegister.files_to_cleanup.append(target.name) - + lib = 
ctypes.pydll.LoadLibrary(target.name) - initfunc = getattr(lib, "init"+modulename) + initfunc = getattr(lib, "init" + modulename) initfunc() result = sys.modules[modulename] result.__ctypeslib__ = lib @@ -68,64 +75,71 @@ def import_unique(modulename): del sys.modules[modulename] else: sys.modules[modulename] = prevmodule - + + def cleanup(): for filename in _ModuleRegister.files_to_cleanup: if os.path.exists(filename): try: os.remove(filename) except Exception as ex: - print("Could not delete file:",filename,", exception:",ex) + print("Could not delete file:", filename, ", exception:", ex) + # this struct will be passed as a ponter, # so we don't have to worry about the right layout class dl_phdr_info(ctypes.Structure): - _fields_ = [ - ('padding0', ctypes.c_void_p), # ignore it - ('dlpi_name', ctypes.c_char_p), - # ignore the reset - ] + _fields_ = [ + ("padding0", ctypes.c_void_p), # ignore it + ("dlpi_name", ctypes.c_char_p), + # ignore the reset + ] # call back function, I changed c_void_p to c_char_p -callback_t = ctypes.CFUNCTYPE(ctypes.c_int, - ctypes.POINTER(dl_phdr_info), - ctypes.POINTER(ctypes.c_size_t), ctypes.c_char_p) +callback_t = ctypes.CFUNCTYPE( + ctypes.c_int, + ctypes.POINTER(dl_phdr_info), + ctypes.POINTER(ctypes.c_size_t), + ctypes.c_char_p, +) -dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr +dl_iterate_phdr = ctypes.CDLL("libc.so.6").dl_iterate_phdr # I changed c_void_p to c_char_p dl_iterate_phdr.argtypes = [callback_t, ctypes.c_char_p] dl_iterate_phdr.restype = ctypes.c_int count = [0] + + def callback(info, size, data): # simple search print("CLEANUP:", info.contents.dlpi_name) count[0] += 1 return 0 - + + def cleanup_module(mod): - #print "CLEANUP!!" - #sys.stdout.flush() - #print "CLEANUP:", mod, len(list(os.listdir('/proc/self/fd'))) - #count[0] = 0 - #dl_iterate_phdr(callback_t(callback), "") - #print "CLEANUP:", count[0] + # print "CLEANUP!!" 
+ # sys.stdout.flush() + # print "CLEANUP:", mod, len(list(os.listdir('/proc/self/fd'))) + # count[0] = 0 + # dl_iterate_phdr(callback_t(callback), "") + # print "CLEANUP:", count[0] sys.stdout.flush() - - if hasattr(mod, '__ctypeslib__') and not mod.__ctypeslib__ is None: + + if hasattr(mod, "__ctypeslib__") and not mod.__ctypeslib__ is None: lib = mod.__ctypeslib__ - dlclose = ctypes.cdll.LoadLibrary('libdl.so').dlclose + dlclose = ctypes.cdll.LoadLibrary("libdl.so").dlclose dlclose.argtypes = [ctypes.c_void_p] dlclose.restype = ctypes.c_int - errorcode = dlclose(lib._handle) + errorcode = dlclose(lib._handle) mod.__ctypeslib__ = None filename = mod.__ctypesfilename__ if os.path.exists(filename): try: os.remove(filename) except Exception as ex: - print("CLEANUP Could not delete file:",filename,", exception:",ex) + print("CLEANUP Could not delete file:", filename, ", exception:", ex) mod.__ctypesfilename__ = None - diff --git a/src/amuse/rfi/nospawn.py b/src/amuse/rfi/nospawn.py index 94e2235fdd..da3733ec15 100644 --- a/src/amuse/rfi/nospawn.py +++ b/src/amuse/rfi/nospawn.py @@ -1,4 +1,3 @@ - from amuse.rfi import core from amuse.rfi.python_code import CythonImplementation from mpi4py import MPI @@ -7,8 +6,11 @@ import sys import importlib -Code = namedtuple("Code", ['cls', 'number_of_workers', 'args', 'kwargs']) -PythonCode = namedtuple("Code", ['cls', 'number_of_workers', 'args', 'kwargs', 'implementation_factory']) +Code = namedtuple("Code", ["cls", "number_of_workers", "args", "kwargs"]) +PythonCode = namedtuple( + "Code", ["cls", "number_of_workers", "args", "kwargs", "implementation_factory"] +) + def get_number_of_workers_needed(codes): result = 1 @@ -16,6 +18,7 @@ def get_number_of_workers_needed(codes): result += x.number_of_workers return result + def get_color(rank, codes): if rank == 0: return 0 @@ -25,8 +28,9 @@ def get_color(rank, codes): if rank >= index and rank < index + x.number_of_workers: return color + 1 index += x.number_of_workers - return 
len(codes) + 1 #left over ranks - + return len(codes) + 1 # left over ranks + + def get_key(rank, codes): if rank == 0: return 0 @@ -36,7 +40,8 @@ def get_key(rank, codes): if rank >= index and rank < index + x.number_of_workers: return rank - index index += x.number_of_workers - return rank - (len(codes) + 1) #left over ranks + return rank - (len(codes) + 1) # left over ranks + def get_code_class(rank, codes): if rank == 0: @@ -48,28 +53,29 @@ def get_code_class(rank, codes): return x.cls index += x.number_of_workers return None - - + def start_all(codes): - channel.MpiChannel.ensure_mpi_initialized() number_of_workers_needed = get_number_of_workers_needed(codes) - + world = MPI.COMM_WORLD rank = world.rank if world.size < number_of_workers_needed: if rank == 0: - raise Exception("cannot start all codes, the world size ({0}) is smaller than the number of requested codes ({1}) (which is always 1 + the sum of the all the number_of_worker fields)".format(world.size, number_of_workers_needed)) + raise Exception( + "cannot start all codes, the world size ({0}) is smaller than the number of requested codes ({1}) (which is always 1 + the sum of the all the number_of_worker fields)".format( + world.size, number_of_workers_needed + ) + ) else: return None - + color = get_color(world.rank, codes) key = get_key(world.rank, codes) - + newcomm = world.Split(color, key) - - + localdup = world.Dup() if world.rank == 0: result = [] @@ -79,78 +85,80 @@ def start_all(codes): new_intercomm = newcomm.Create_intercomm(0, localdup, remote_leader, tag) remote_leader += x.number_of_workers tag += 1 - instance = x.cls(*x.args, check_mpi = False, must_start_worker = False, **x.kwargs) - instance.legacy_interface.channel = channel.MpiChannel('_',None) + instance = x.cls( + *x.args, check_mpi=False, must_start_worker=False, **x.kwargs + ) + instance.legacy_interface.channel = channel.MpiChannel("_", None) instance.legacy_interface.channel.intercomm = new_intercomm result.append(instance) - + 
world.Barrier() - - return result + + return result else: code_cls = get_code_class(world.rank, codes) if code_cls is None: world.Barrier() return None - + new_intercomm = newcomm.Create_intercomm(0, localdup, 0, color) x = get_code(world.rank, codes) - instance = code_cls(*x.args, check_mpi = False, must_start_worker = False, **x.kwargs) + instance = code_cls( + *x.args, check_mpi=False, must_start_worker=False, **x.kwargs + ) interface = instance.legacy_interface - - if hasattr(interface, '__so_module__'): - package, _ = code_cls.__module__.rsplit('.',1) - modulename = package + '.' + interface.__so_module__ + + if hasattr(interface, "__so_module__"): + package, _ = code_cls.__module__.rsplit(".", 1) + modulename = package + "." + interface.__so_module__ module = importlib.import_module(modulename) module.set_comm_world(newcomm) else: module = x.implementation_factory() - - + instance = CythonImplementation(module, interface.__class__) instance.intercomm = new_intercomm instance.must_disconnect = False world.Barrier() instance.start() - + return None - + def stop_all(instances): for x in instances: x.stop() + + def start_empty(): - channel.MpiChannel.ensure_mpi_initialized() - + world = MPI.COMM_WORLD rank = world.rank - + color = 0 if world.rank == 0 else 1 - key = 0 if world.rank == 0 else world.rank -1 + key = 0 if world.rank == 0 else world.rank - 1 newcomm = world.Split(color, key) - + localdup = world.Dup() if world.rank == 0: result = [] remote_leader = 1 tag = 1 - + new_intercomm = newcomm.Create_intercomm(0, localdup, remote_leader, tag) - - instance = core.CodeInterface(check_mpi = False, must_start_worker = False) - instance.channel = channel.MpiChannel('_',None) + + instance = core.CodeInterface(check_mpi=False, must_start_worker=False) + instance.channel = channel.MpiChannel("_", None) instance.channel.intercomm = new_intercomm - instance.world = localdup + instance.world = localdup instance.remote_leader = 1 world.Barrier() - - return instance + + 
return instance else: - new_intercomm = newcomm.Create_intercomm(0, localdup, 0, color) - - + instance = CythonImplementation(None, core.CodeInterface) instance.intercomm = new_intercomm instance.world = localdup @@ -161,7 +169,6 @@ def start_empty(): instance.start() print("STOP...", world.rank) return None - def get_code(rank, codes): @@ -174,7 +181,3 @@ def get_code(rank, codes): return x index += x.number_of_workers return None - - - - diff --git a/src/amuse/rfi/python_code.py b/src/amuse/rfi/python_code.py index 706a7ee41a..6bfa2d05d2 100644 --- a/src/amuse/rfi/python_code.py +++ b/src/amuse/rfi/python_code.py @@ -18,28 +18,28 @@ from amuse.rfi.core import legacy_function from amuse.rfi.core import LegacyFunctionSpecification + class ValueHolder(object): - - def __init__(self, value = None): + def __init__(self, value=None): self.value = value - + def __repr__(self): return "V({0!r})".format(self.value) def __str__(self): return "V({0!s})".format(self.value) - + class PythonImplementation(object): - dtype_to_message_attribute = { - 'int32' : 'ints', - 'float64' : 'doubles', - 'float32' : 'floats', - 'string' : 'strings', - 'bool' : 'booleans', - 'int64' : 'longs', + dtype_to_message_attribute = { + "int32": "ints", + "float64": "doubles", + "float32": "floats", + "string": "strings", + "bool": "booleans", + "int64": "longs", } - + def __init__(self, implementation, interface): self.implementation = implementation self.interface = interface @@ -51,9 +51,8 @@ def __init__(self, implementation, interface): self.id_to_activate = -1 if not self.implementation is None: self.implementation._interface = self - - def start(self, mpi_port = None): + def start(self, mpi_port=None): if mpi_port is None: parent = self.intercomm self.communicators.append(parent) @@ -62,22 +61,24 @@ def start(self, mpi_port = None): self.communicators.append(parent) self.activeid = 0 self.lastid += 1 - + rank = parent.Get_rank() - + self.must_run = True while self.must_run: if 
self.id_to_activate >= 0 and self.id_to_activate != self.activeid: - warnings.warn("activating: "+str(self.id_to_activate)) + warnings.warn("activating: " + str(self.id_to_activate)) self.activeid = self.id_to_activate self.id_to_activate = -1 parent = self.communicators[self.activeid] rank = parent.Get_rank() - message = ClientSideMPIMessage(polling_interval = self.polling_interval) + message = ClientSideMPIMessage(polling_interval=self.polling_interval) message.receive(parent) - result_message = ClientSideMPIMessage(message.call_id, message.function_id, message.call_count) - + result_message = ClientSideMPIMessage( + message.call_id, message.function_id, message.call_count + ) + if message.function_id == 0: self.must_run = False else: @@ -88,183 +89,214 @@ def start(self, mpi_port = None): warnings.warn(str(ex)) traceback.print_exc() result_message.set_error(str(ex)) - #for type, attribute in self.dtype_to_message_attribute.iteritems(): + # for type, attribute in self.dtype_to_message_attribute.iteritems(): # setattr(result_message, attribute, []) - + for type, attribute in self.dtype_to_message_attribute.items(): array = getattr(result_message, attribute) packed = pack_array(array, result_message.call_count, type) setattr(result_message, attribute, packed) - + else: - result_message.set_error("unknown function id " + str(message.function_id)) - + result_message.set_error( + "unknown function id " + str(message.function_id) + ) + if rank == 0: result_message.send(parent) if self.must_disconnect: for x in self.communicators: x.Disconnect() - - - - - def start_socket(self, port, host): client_socket = socket.create_connection((host, port)) - + self.must_run = True while self.must_run: - message = SocketMessage() message.receive(client_socket) - - result_message = SocketMessage(message.call_id, message.function_id, message.call_count) - + + result_message = SocketMessage( + message.call_id, message.function_id, message.call_count + ) + if message.function_id == 0: 
self.must_run = False else: if message.function_id in self.mapping_from_tag_to_legacy_function: try: self.handle_message(message, result_message) - except BaseException as ex: + except BaseException as ex: traceback.print_exc() result_message.set_error(ex.__str__()) for type, attribute in self.dtype_to_message_attribute.items(): array = getattr(result_message, attribute) packed = pack_array(array, result_message.call_count, type) setattr(result_message, attribute, packed) - + else: - result_message.set_error("unknown function id " + message.function_id) - + result_message.set_error( + "unknown function id " + message.function_id + ) + result_message.send(client_socket) - + client_socket.close() - + def start_socket_mpi(self, port, host): - rank=MPI.COMM_WORLD.Get_rank() + rank = MPI.COMM_WORLD.Get_rank() - if rank==0: + if rank == 0: client_socket = socket.create_connection((host, port)) - + self.must_run = True while self.must_run: - - if rank==0: + if rank == 0: message = SocketMessage() message.receive(client_socket) else: - message=None - - message=MPI.COMM_WORLD.bcast(message, root=0) - - result_message = SocketMessage(message.call_id, message.function_id, message.call_count) - + message = None + + message = MPI.COMM_WORLD.bcast(message, root=0) + + result_message = SocketMessage( + message.call_id, message.function_id, message.call_count + ) + if message.function_id == 0: self.must_run = False else: if message.function_id in self.mapping_from_tag_to_legacy_function: try: self.handle_message(message, result_message) - except BaseException as ex: + except BaseException as ex: traceback.print_exc() result_message.set_error(ex.__str__()) for type, attribute in self.dtype_to_message_attribute.items(): array = getattr(result_message, attribute) packed = pack_array(array, result_message.call_count, type) setattr(result_message, attribute, packed) - + else: - result_message.set_error("unknown function id " + message.function_id) - - if rank==0: + 
result_message.set_error( + "unknown function id " + message.function_id + ) + + if rank == 0: result_message.send(client_socket) - - if rank==0: - client_socket.close() + if rank == 0: + client_socket.close() def handle_message(self, input_message, output_message): - legacy_function = self.mapping_from_tag_to_legacy_function[input_message.function_id] + legacy_function = self.mapping_from_tag_to_legacy_function[ + input_message.function_id + ] specification = legacy_function.specification dtype_to_count = self.get_dtype_to_count(specification) - - + if hasattr(specification, "internal_provided"): method = getattr(self, specification.name) else: method = getattr(self.implementation, specification.name) - + if specification.has_units: input_units = self.convert_floats_to_units(input_message.encoded_units) else: input_units = () - + for type, attribute in self.dtype_to_message_attribute.items(): - count = dtype_to_count.get(type,0) + count = dtype_to_count.get(type, 0) for x in range(count): - if type == 'string': - getattr(output_message, attribute).append([""] * output_message.call_count) + if type == "string": + getattr(output_message, attribute).append( + [""] * output_message.call_count + ) else: - getattr(output_message, attribute).append(numpy.zeros(output_message.call_count, dtype=type)) + getattr(output_message, attribute).append( + numpy.zeros(output_message.call_count, dtype=type) + ) for type, attribute in self.dtype_to_message_attribute.items(): array = getattr(input_message, attribute) unpacked = unpack_array(array, input_message.call_count, type) - setattr(input_message,attribute, unpacked) - + setattr(input_message, attribute, unpacked) + units = [False] * len(specification.output_parameters) if specification.must_handle_array: - keyword_arguments = self.new_keyword_arguments_from_message(input_message, None, specification, input_units) - try: + keyword_arguments = self.new_keyword_arguments_from_message( + input_message, None, specification, 
input_units + ) + try: result = method(**keyword_arguments) except TypeError as ex: - warnings.warn("mismatch in python function specification(?): "+str(ex)) + warnings.warn( + "mismatch in python function specification(?): " + str(ex) + ) result = method(*list(keyword_arguments)) - self.fill_output_message(output_message, None, result, keyword_arguments, specification, units) + self.fill_output_message( + output_message, None, result, keyword_arguments, specification, units + ) else: for index in range(input_message.call_count): - keyword_arguments = self.new_keyword_arguments_from_message(input_message, index, specification, input_units) + keyword_arguments = self.new_keyword_arguments_from_message( + input_message, index, specification, input_units + ) try: result = method(**keyword_arguments) except TypeError as ex: - warnings.warn("mismatch in python function specification(?): "+str(ex)) + warnings.warn( + "mismatch in python function specification(?): " + str(ex) + ) result = method(*list(keyword_arguments)) - self.fill_output_message(output_message, index, result, keyword_arguments, specification, units) - - + self.fill_output_message( + output_message, + index, + result, + keyword_arguments, + specification, + units, + ) + for type, attribute in self.dtype_to_message_attribute.items(): array = getattr(output_message, attribute) packed = pack_array(array, input_message.call_count, type) setattr(output_message, attribute, packed) - + if specification.has_units: output_message.encoded_units = self.convert_output_units_to_floats(units) - - - def new_keyword_arguments_from_message(self, input_message, index, specification, units = []): + def new_keyword_arguments_from_message( + self, input_message, index, specification, units=[] + ): keyword_arguments = OrderedDictionary() for parameter in specification.parameters: attribute = self.dtype_to_message_attribute[parameter.datatype] argument_value = None if parameter.direction == LegacyFunctionSpecification.IN: if 
specification.must_handle_array: - argument_value = getattr(input_message, attribute)[parameter.input_index] + argument_value = getattr(input_message, attribute)[ + parameter.input_index + ] else: - argument_value = getattr(input_message, attribute)[parameter.input_index][index] + argument_value = getattr(input_message, attribute)[ + parameter.input_index + ][index] if specification.has_units: unit = units[parameter.index_in_input] if not unit is None: argument_value = argument_value | unit elif parameter.direction == LegacyFunctionSpecification.INOUT: if specification.must_handle_array: - argument_value = ValueHolder(getattr(input_message, attribute)[parameter.input_index]) + argument_value = ValueHolder( + getattr(input_message, attribute)[parameter.input_index] + ) else: - argument_value = ValueHolder(getattr(input_message, attribute)[parameter.input_index][index]) - + argument_value = ValueHolder( + getattr(input_message, attribute)[parameter.input_index][index] + ) + if specification.has_units: unit = units[parameter.index_in_input] if not unit is None: @@ -273,25 +305,28 @@ def new_keyword_arguments_from_message(self, input_message, index, specification argument_value = ValueHolder(None) elif parameter.direction == LegacyFunctionSpecification.LENGTH: argument_value = input_message.call_count - name = 'in_' if parameter.name == 'in' else parameter.name + name = "in_" if parameter.name == "in" else parameter.name keyword_arguments[name] = argument_value return keyword_arguments - - def fill_output_message(self, output_message, index, result, keyword_arguments, specification, units): + def fill_output_message( + self, output_message, index, result, keyword_arguments, specification, units + ): from amuse.units import quantities - + if not specification.result_type is None: attribute = self.dtype_to_message_attribute[specification.result_type] if specification.must_handle_array: getattr(output_message, attribute)[0] = result else: getattr(output_message, 
attribute)[0][index] = result - + for parameter in specification.parameters: attribute = self.dtype_to_message_attribute[parameter.datatype] - if (parameter.direction == LegacyFunctionSpecification.OUT or - parameter.direction == LegacyFunctionSpecification.INOUT): + if ( + parameter.direction == LegacyFunctionSpecification.OUT + or parameter.direction == LegacyFunctionSpecification.INOUT + ): argument_value = keyword_arguments[parameter.name] output = argument_value.value if specification.has_units: @@ -305,68 +340,70 @@ def fill_output_message(self, output_message, index, result, keyword_arguments, if specification.must_handle_array: getattr(output_message, attribute)[parameter.output_index] = output else: - getattr(output_message, attribute)[parameter.output_index][index] = output - + getattr(output_message, attribute)[parameter.output_index][ + index + ] = output + def get_dtype_to_count(self, specification): dtype_to_count = {} - + for parameter in specification.output_parameters: count = dtype_to_count.get(parameter.datatype, 0) dtype_to_count[parameter.datatype] = count + 1 - + if not specification.result_type is None: count = dtype_to_count.get(specification.result_type, 0) dtype_to_count[specification.result_type] = count + 1 - + return dtype_to_count - + @late def mapping_from_tag_to_legacy_function(self): result = {} for x in self.interface_functions: result[x.specification.id] = x return result - + @late def interface_functions(self): attribute_names = dir(self.interface) interface_functions = [] for x in attribute_names: - if x.startswith('__'): + if x.startswith("__"): continue value = getattr(self.interface, x) if isinstance(value, legacy_function): interface_functions.append(value) - - interface_functions.sort(key= lambda x: x.specification.id) - + + interface_functions.sort(key=lambda x: x.specification.id) + for x in interface_functions: x.specification.prepare_output_parameters() - + return interface_functions - - def 
internal__set_message_polling_interval(self, inval): self.polling_interval = inval return 0 - + def internal__get_message_polling_interval(self, outval): - outval.value = self.polling_interval + outval.value = self.polling_interval return 0 - + def get_null_info(self): - return getattr(MPI, 'INFO_NULL') if hasattr(MPI, 'INFO_NULL') else None - + return getattr(MPI, "INFO_NULL") if hasattr(MPI, "INFO_NULL") else None + def internal__open_port(self, port_identifier): port_identifier.value = MPI.Open_port(self.get_null_info()) return 0 - + def internal__accept_on_port(self, port_identifier, comm_identifier): new_communicator = None rank = MPI.COMM_WORLD.Get_rank() if rank == 0: - communicator = MPI.COMM_SELF.Accept(port_identifier, self.get_null_info(), 0) + communicator = MPI.COMM_SELF.Accept( + port_identifier, self.get_null_info(), 0 + ) merged = communicator.Merge(False) new_communicator = MPI.COMM_WORLD.Create_intercomm(0, merged, 1, 65) @@ -374,18 +411,19 @@ def internal__accept_on_port(self, port_identifier, comm_identifier): communicator.Free() else: new_communicator = MPI.COMM_WORLD.Create_intercomm(0, MPI.COMM_WORLD, 1, 65) - + self.communicators.append(new_communicator) self.lastid += 1 comm_identifier.value = self.lastid return 0 - - + def internal__connect_to_port(self, port_identifier, comm_identifier): new_communicator = None rank = MPI.COMM_WORLD.Get_rank() if rank == 0: - communicator = MPI.COMM_SELF.Connect(port_identifier, self.get_null_info(), 0) + communicator = MPI.COMM_SELF.Connect( + port_identifier, self.get_null_info(), 0 + ) merged = communicator.Merge(True) new_communicator = MPI.COMM_WORLD.Create_intercomm(0, merged, 0, 65) @@ -394,49 +432,46 @@ def internal__connect_to_port(self, port_identifier, comm_identifier): communicator.Free() else: new_communicator = MPI.COMM_WORLD.Create_intercomm(0, MPI.COMM_WORLD, 0, 65) - + self.communicators.append(new_communicator) self.lastid += 1 comm_identifier.value = self.lastid return 0 - + def 
internal__activate_communicator(self, comm_identifier): if comm_identifier > self.lastid or comm_identifier < 0: return -1 self.id_to_activate = comm_identifier return 0 - - - + def internal__redirect_outputs(self, stdoutfile, stderrfile): mpi_rank = MPI.COMM_WORLD.rank sys.stdin.close() try: os.close(0) except Exception as ex: - warnings.warn( str(ex)) - + warnings.warn(str(ex)) + if stdoutfile != "none": if stdoutfile != "/dev/null": fullname = "{0:s}.{1:03d}".format(stdoutfile, mpi_rank) else: fullname = stdoutfile - - + sys.stdout.close() sys.stdout = open(fullname, "a+") - + if stderrfile != "none": - if stderrfile != "/dev/null": + if stderrfile != "/dev/null": fullname = "{0:s}.{1:03d}".format(stderrfile, mpi_rank) else: fullname = stderrfile - + sys.stderr.close() - sys.stderr = open(fullname, "a+") - + sys.stderr = open(fullname, "a+") + return 0 - + def convert_to_unit(self, units_as_floats, index): return None @@ -447,16 +482,16 @@ def convert_unit_to_floats(self, unit): return unit.to_array_of_floats() def convert_output_units_to_floats(self, units): - result = numpy.zeros(len(units) * 9, dtype = numpy.float64) + result = numpy.zeros(len(units) * 9, dtype=numpy.float64) for index, unit in enumerate(units): - offset = index*9 - result[offset:offset+9] = self.convert_unit_to_floats(unit) + offset = index * 9 + result[offset : offset + 9] = self.convert_unit_to_floats(unit) return result - def convert_float_to_unit(self, floats): from amuse.units import core from amuse.units import units + if numpy.all(floats == 0): return None factor = floats[0] @@ -470,25 +505,22 @@ def convert_float_to_unit(self, floats): for x in unit_system.bases: power = floats[x.index + 2] if not power == 0.0: - result = result * (x ** power) + result = result * (x**power) return result - - def convert_floats_to_units(self, floats): result = [] for index in range(len(floats) // 9): - offset = index*9 - unit_floats = floats[offset:offset+9] + offset = index * 9 + unit_floats = 
floats[offset : offset + 9] unit = self.convert_float_to_unit(unit_floats) result.append(unit) return result - - @late def intercomm(self): return MPI.Comm.Get_parent() + @late def must_disconnect(self): return True @@ -498,8 +530,12 @@ def internal__become_code(self, number_of_workers, modulename, classname): #~ print number_of_workers, modulename, classname world = self.freeworld color = 0 if world.rank < number_of_workers else 1 - key = world.rank if world.rank < number_of_workers else world.rank - number_of_workers - #~ print "CC,", color, key, world.rank, world.size + key = ( + world.rank + if world.rank < number_of_workers + else world.rank - number_of_workers + ) + # print "CC,", color, key, world.rank, world.size newcomm = world.Split(color, key) #~ print ("nc:", newcomm.size, newcomm.rank) #~ print ("AA", self.world, color, self.world.rank, self.world.size) @@ -508,90 +544,98 @@ def internal__become_code(self, number_of_workers, modulename, classname): except Exception as ex: warnings.warn(str(ex)) raise ex - #~ print ("nccc:", new_intercomm.Get_remote_size(), new_intercomm.rank) - + # print ("nccc:", new_intercomm.Get_remote_size(), new_intercomm.rank) + self.communicators.append(new_intercomm) self.id_to_activate = len(self.communicators) - 1 self.freeworld = newcomm return 0 - + def set_working_directory(self, d): try: - os.chdir(d) - return 0 + os.chdir(d) + return 0 except Exception: - return -1 + return -1 def get_working_directory(self, d): try: - d.value=os.getcwd() - return 0 + d.value = os.getcwd() + return 0 except Exception: - return -1 - - + return -1 class CythonImplementation(PythonImplementation): - - def handle_message(self, input_message, output_message): - legacy_function = self.mapping_from_tag_to_legacy_function[input_message.function_id] + legacy_function = self.mapping_from_tag_to_legacy_function[ + input_message.function_id + ] specification = legacy_function.specification - + dtype_to_count = self.get_dtype_to_count(specification) 
- - if specification.name == '_stop_worker': - method = lambda : None - elif hasattr(specification,"internal_provided"): + + if specification.name == "_stop_worker": + method = lambda: None + elif hasattr(specification, "internal_provided"): method = getattr(self, specification.name) else: method = getattr(self.implementation, specification.name) - + if specification.has_units: input_units = self.convert_floats_to_units(input_message.encoded_units) else: input_units = () - + for type, attribute in self.dtype_to_message_attribute.items(): - count = dtype_to_count.get(type,0) + count = dtype_to_count.get(type, 0) for x in range(count): - if type == 'string': - getattr(output_message, attribute).append([""] * output_message.call_count) + if type == "string": + getattr(output_message, attribute).append( + [""] * output_message.call_count + ) else: - getattr(output_message, attribute).append(numpy.zeros(output_message.call_count, dtype=type)) + getattr(output_message, attribute).append( + numpy.zeros(output_message.call_count, dtype=type) + ) for type, attribute in self.dtype_to_message_attribute.items(): array = getattr(input_message, attribute) unpacked = unpack_array(array, input_message.call_count, type) - setattr(input_message,attribute, unpacked) - + setattr(input_message, attribute, unpacked) + units = [False] * len(specification.output_parameters) if specification.must_handle_array: - keyword_arguments = self.new_keyword_arguments_from_message(input_message, None, specification, input_units) + keyword_arguments = self.new_keyword_arguments_from_message( + input_message, None, specification, input_units + ) result = method(**keyword_arguments) - self.fill_output_message(output_message, None, result, keyword_arguments, specification, units) + self.fill_output_message( + output_message, None, result, keyword_arguments, specification, units + ) else: for index in range(input_message.call_count): - #print "INDEX:", index - keyword_arguments = 
self.new_keyword_arguments_from_message(input_message, index, specification, input_units) + # print "INDEX:", index + keyword_arguments = self.new_keyword_arguments_from_message( + input_message, index, specification, input_units + ) try: result = method(**keyword_arguments) except TypeError as ex: result = method(*list(keyword_arguments)) - self.fill_output_message(output_message, index, result, keyword_arguments, specification, units) - - + self.fill_output_message( + output_message, + index, + result, + keyword_arguments, + specification, + units, + ) + for type, attribute in self.dtype_to_message_attribute.items(): array = getattr(output_message, attribute) packed = pack_array(array, input_message.call_count, type) setattr(output_message, attribute, packed) - + if specification.has_units: output_message.encoded_units = self.convert_output_units_to_floats(units) - - - - - - diff --git a/src/amuse/rfi/run_command_redirected.py b/src/amuse/rfi/run_command_redirected.py index 454dc887bc..3f2f0e7257 100644 --- a/src/amuse/rfi/run_command_redirected.py +++ b/src/amuse/rfi/run_command_redirected.py @@ -5,49 +5,46 @@ import time import signal + def translate_filename_for_os(filename): - if sys.platform == 'win32': - if filename == '/dev/null': - return 'nul' + if sys.platform == "win32": + if filename == "/dev/null": + return "nul" else: return filename else: return filename -if __name__ == '__main__': + +if __name__ == "__main__": stdoutfname = None - if sys.argv[1] == 'none': + if sys.argv[1] == "none": stdout = None else: - stdoutfname=translate_filename_for_os(sys.argv[1]) - stdout = open(stdoutfname,'w') - - if sys.argv[2] == 'none': + stdoutfname = translate_filename_for_os(sys.argv[1]) + stdout = open(stdoutfname, "w") + + if sys.argv[2] == "none": stderr = None else: - stderrfname=translate_filename_for_os(sys.argv[2]) - if sys.argv[2] != '/dev/null' and stdoutfname == stderrfname: - stderr = open(stderrfname,'a') + stderrfname = 
translate_filename_for_os(sys.argv[2]) + if sys.argv[2] != "/dev/null" and stdoutfname == stderrfname: + stderr = open(stderrfname, "a") else: - stderr = open(stderrfname,'w') - - - stdin = open(translate_filename_for_os('/dev/null'),'r') - + stderr = open(stderrfname, "w") + + stdin = open(translate_filename_for_os("/dev/null"), "r") + returncode = call( - sys.argv[3:], - stdout = stdout, - stderr = stderr, - stdin = stdin, - close_fds = False + sys.argv[3:], stdout=stdout, stderr=stderr, stdin=stdin, close_fds=False ) stdin.close() - + if not stdout is None: stdout.close() - + if not stderr is None: stderr.close() - + sys.exit(returncode) diff --git a/src/amuse/rfi/slurm.py b/src/amuse/rfi/slurm.py index 616e137bba..4ee4f3abf2 100644 --- a/src/amuse/rfi/slurm.py +++ b/src/amuse/rfi/slurm.py @@ -1,48 +1,46 @@ - - - def parse_slurm_tasks_per_node(string): - per_node = string.split(',') + per_node = string.split(",") result = [] for node in per_node: - parts = node.split('(') + parts = node.split("(") count = parts[0] if len(parts) == 2: nodes = parts[1] else: nodes = None - + try: count = int(count) except: - count = 0 # unparsable number + count = 0 # unparsable number if nodes: - nodes = nodes[1:-1] # skip the 'x' character and the closing ')' character + nodes = nodes[1:-1] # skip the 'x' character and the closing ')' character try: nodes = int(nodes) except: - nodes = 1 # unparsable number, assume 1 + nodes = 1 # unparsable number, assume 1 for _ in range(nodes): result.append(count) else: result.append(count) return result - + + def parse_slurm_nodelist(string): result = [] - + name_characters = [] position = 0 while position < len(string): char = string[position] - if char == '[': - name = ''.join(name_characters) + if char == "[": + name = "".join(name_characters) ids, position = parse_ids(string, position) for x in ids: result.append(name + x) name_characters = [] - elif char == ',': - name = ''.join(name_characters) + elif char == ",": + name = 
"".join(name_characters) result.append(name) name_characters = [] position += 1 @@ -50,24 +48,26 @@ def parse_slurm_nodelist(string): name_characters.append(char) position += 1 if len(name_characters) > 0: - name = ''.join(name_characters) + name = "".join(name_characters) result.append(name) name_characters = [] return result - + + def parse_ids(string, position): result = [] - end = string.index(']',position) - count_ranges = string[position+1:end] - for count_range in count_ranges.split(','): - if '-' in count_range: - from_id, to_id = count_range.split('-') + end = string.index("]", position) + count_ranges = string[position + 1 : end] + for count_range in count_ranges.split(","): + if "-" in count_range: + from_id, to_id = count_range.split("-") for number in range(int(from_id), int(to_id) + 1): result.append(str(number)) else: result.append(count_range) - return result, end+1 - + return result, end + 1 + + if __name__ == "__main__": print(parse_slurm_tasks_per_node("10(x4),3")) print(parse_slurm_nodelist("tcn[595,597-598,600-606],tcn100")) diff --git a/src/amuse/rfi/tools/create_c.py b/src/amuse/rfi/tools/create_c.py index b92a9dc7af..16d9eb1114 100644 --- a/src/amuse/rfi/tools/create_c.py +++ b/src/amuse/rfi/tools/create_c.py @@ -8,20 +8,32 @@ from amuse.rfi.tools import create_definition from amuse.rfi.core import LegacyFunctionSpecification -dtype_to_spec = DTypeToSpecDictionary({ - 'int32' : DTypeSpec('ints_in', 'ints_out', - 'HEADER_INTEGER_COUNT', 'int', 'MPI_INT'), - 'int64' : DTypeSpec('longs_in', 'longs_out', - 'HEADER_LONG_COUNT', 'long long int', 'MPI_LONG_LONG_INT'), - 'float32' : DTypeSpec('floats_in', 'floats_out', - 'HEADER_FLOAT_COUNT', 'float', 'MPI_FLOAT'), - 'float64' : DTypeSpec('doubles_in', 'doubles_out', - 'HEADER_DOUBLE_COUNT', 'double', 'MPI_DOUBLE'), - 'bool' : DTypeSpec('booleans_in', 'booleans_out', - 'HEADER_BOOLEAN_COUNT', 'bool', 'MPI_C_BOOL'), - 'string' : DTypeSpec('strings_in', 'strings_out', - 'HEADER_STRING_COUNT', 'int', 
'MPI_INTEGER'), -}) +dtype_to_spec = DTypeToSpecDictionary( + { + "int32": DTypeSpec( + "ints_in", "ints_out", "HEADER_INTEGER_COUNT", "int", "MPI_INT" + ), + "int64": DTypeSpec( + "longs_in", + "longs_out", + "HEADER_LONG_COUNT", + "long long int", + "MPI_LONG_LONG_INT", + ), + "float32": DTypeSpec( + "floats_in", "floats_out", "HEADER_FLOAT_COUNT", "float", "MPI_FLOAT" + ), + "float64": DTypeSpec( + "doubles_in", "doubles_out", "HEADER_DOUBLE_COUNT", "double", "MPI_DOUBLE" + ), + "bool": DTypeSpec( + "booleans_in", "booleans_out", "HEADER_BOOLEAN_COUNT", "bool", "MPI_C_BOOL" + ), + "string": DTypeSpec( + "strings_in", "strings_out", "HEADER_STRING_COUNT", "int", "MPI_INTEGER" + ), + } +) HEADER_CODE_STRING = """ #ifndef NOMPI @@ -1043,7 +1055,7 @@ """ -GETSET_WORKING_DIRECTORY=""" +GETSET_WORKING_DIRECTORY = """ char path_buffer[4096]; @@ -1062,402 +1074,441 @@ """ - - - class MakeCCodeString(GenerateASourcecodeString): @late def dtype_to_spec(self): return dtype_to_spec - - + class GenerateACStringOfAFunctionSpecification(MakeCCodeString): @late def specification(self): - raise exceptions.AmuseException("No specification set, please set the specification first") - - + raise exceptions.AmuseException( + "No specification set, please set the specification first" + ) + def start(self): - self.specification.prepare_output_parameters() self.output_casestmt_start() self.out.indent() - + if self.specification.must_handle_array: pass elif self.specification.can_handle_array: - self.out.lf() + 'for (int i = 0 ; i < call_count; i++){' + self.out.lf() + "for (int i = 0 ; i < call_count; i++){" self.out.indent() - + self.output_copy_inout_variables() self.output_function_start() self.output_function_parameters() self.output_function_end() - + if self.specification.must_handle_array: if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] - self.out.lf() + 'for (int i = 1 ; i < call_count; i++){' + self.out.lf() + "for (int i 
= 1 ; i < call_count; i++){" self.out.indent() - self.out.lf() + spec.output_var_name + '[i]' + ' = ' + spec.output_var_name + '[0]' + ';' + ( + self.out.lf() + + spec.output_var_name + + "[i]" + + " = " + + spec.output_var_name + + "[0]" + + ";" + ) self.out.dedent() - self.out.lf() + '}' + self.out.lf() + "}" elif self.specification.can_handle_array: self.out.dedent() - self.out.lf() + '}' - + self.out.lf() + "}" + self.output_lines_with_number_of_outputs() self.output_casestmt_end() self.out.dedent() self._result = self.out.string - + def index_string(self, index, must_copy_in_to_out=False): if self.specification.must_handle_array and not must_copy_in_to_out: if index == 0: - return '0' + return "0" else: - return '( %d * call_count)' % index - elif self.specification.can_handle_array or (self.specification.must_handle_array and must_copy_in_to_out): + return "( %d * call_count)" % index + elif self.specification.can_handle_array or ( + self.specification.must_handle_array and must_copy_in_to_out + ): if index == 0: - return 'i' + return "i" else: - return '( %d * call_count) + i' % index + return "( %d * call_count) + i" % index else: return index - - + def input_var(self, name, index): if self.specification.must_handle_array: self.output_var(name, index) else: self.out.n() + name - self.out + '[' + self.index_string(index) + ']' - + self.out + "[" + self.index_string(index) + "]" + def output_var(self, name, index): - self.out.n() + '&' + name - self.out + '[' + self.index_string(index) + ']' - + self.out.n() + "&" + name + self.out + "[" + self.index_string(index) + "]" + def output_function_parameters(self): self.out.indent() - + first = True - + for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if first: first = False else: - self.out + ' ,' - + self.out + " ," + if parameter.direction == LegacyFunctionSpecification.IN: - self.input_var(spec.input_var_name, parameter.input_index) + 
self.input_var(spec.input_var_name, parameter.input_index) if parameter.direction == LegacyFunctionSpecification.INOUT: - self.output_var(spec.output_var_name, parameter.output_index) + self.output_var(spec.output_var_name, parameter.output_index) elif parameter.direction == LegacyFunctionSpecification.OUT: - self.output_var(spec.output_var_name, parameter.output_index) + self.output_var(spec.output_var_name, parameter.output_index) elif parameter.direction == LegacyFunctionSpecification.LENGTH: - self.out.n() + 'call_count' - + self.out.n() + "call_count" + self.out.dedent() - - + def output_copy_inout_variables(self): for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if parameter.direction == LegacyFunctionSpecification.INOUT: if self.specification.must_handle_array: - self.out.lf() + 'for (int i = 0 ; i < call_count; i++){' + self.out.lf() + "for (int i = 0 ; i < call_count; i++){" self.out.indent() self.out.n() + spec.output_var_name - self.out + '[' + self.index_string(parameter.output_index, must_copy_in_to_out=True) + ']' - self.out + ' = ' - self.out + spec.input_var_name + '[' + self.index_string(parameter.input_index, must_copy_in_to_out=True) + ']' + ';' - + ( + self.out + + "[" + + self.index_string( + parameter.output_index, must_copy_in_to_out=True + ) + + "]" + ) + self.out + " = " + ( + self.out + + spec.input_var_name + + "[" + + self.index_string(parameter.input_index, must_copy_in_to_out=True) + + "]" + + ";" + ) + if self.specification.must_handle_array: self.out.dedent() - self.out.lf() + '}' - + self.out.lf() + "}" + def output_lines_with_number_of_outputs(self): dtype_to_count = {} - + for parameter in self.specification.output_parameters: count = dtype_to_count.get(parameter.datatype, 0) dtype_to_count[parameter.datatype] = count + 1 - + if not self.specification.result_type is None: count = dtype_to_count.get(self.specification.result_type, 0) dtype_to_count[self.specification.result_type] = 
count + 1 - - for dtype in dtype_to_count: + + for dtype in dtype_to_count: spec = self.dtype_to_spec[dtype] count = dtype_to_count[dtype] - self.out.n() - self.out + 'header_out[' + spec.counter_name - self.out + '] = ' + count + ' * call_count;' + self.out.n() + self.out + "header_out[" + spec.counter_name + self.out + "] = " + count + " * call_count;" pass - + def output_function_end(self): if len(self.specification.parameters) > 0: self.out.n() - - self.out + ')' + ';' - + + self.out + ")" + ";" + def output_function_start(self): - self.out.n() + self.out.n() if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] self.out + spec.output_var_name - self.out + '[' + self.index_string(0) + ']' + ' = ' - self.out + self.specification.name + '(' - + self.out + "[" + self.index_string(0) + "]" + " = " + self.out + self.specification.name + "(" + def output_casestmt_start(self): - self.out + 'case ' + self.specification.id + ':' - + self.out + "case " + self.specification.id + ":" + def output_casestmt_end(self): - self.out.n() + 'break;' - - + self.out.n() + "break;" + class GenerateACHeaderDefinitionStringFromAFunctionSpecification(MakeCCodeString): - - def start(self): self.output_function_start() self.output_function_parameters() self.output_function_end() self._result = self.out.string - - def output_function_parameters(self): + + def output_function_parameters(self): first = True - + for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if first: first = False else: - self.out + ', ' - - if parameter.datatype == 'string': - self.out + 'char' + self.out + ", " + + if parameter.datatype == "string": + self.out + "char" else: self.out + spec.type - self.out + ' ' - if parameter.is_output() or (parameter.is_input() and self.specification.must_handle_array): - self.out + '*' + ' ' - if parameter.datatype == 'string': - self.out + '*' + ' ' + self.out + " " + if 
parameter.is_output() or ( + parameter.is_input() and self.specification.must_handle_array + ): + self.out + "*" + " " + if parameter.datatype == "string": + self.out + "*" + " " self.out + parameter.name - - + def output_function_end(self): - self.out + ')' + ';' - + self.out + ")" + ";" + def output_function_start(self): self.out.n() if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] self.out + spec.type - self.out + ' ' + self.out + " " else: - self.out + 'void' + ' ' - self.out + self.specification.name + '(' - -class GenerateACSourcecodeStringFromASpecificationClass\ - (GenerateASourcecodeStringFromASpecificationClass): + self.out + "void" + " " + self.out + self.specification.name + "(" + +class GenerateACSourcecodeStringFromASpecificationClass( + GenerateASourcecodeStringFromASpecificationClass +): @late def specification_class(self): - raise exceptions.AmuseException("No specification_class set, please set the specification_class first") - + raise exceptions.AmuseException( + "No specification_class set, please set the specification_class first" + ) + @late def dtype_to_spec(self): return dtype_to_spec def output_sourcecode_for_function(self): return GenerateACStringOfAFunctionSpecification() - + def start(self): self.out + HEADER_CODE_STRING self.output_local_includes() - + self.output_needs_mpi() - + self.output_code_constants() - + self.out.lf() + CONSTANTS_AND_GLOBAL_VARIABLES_STRING - + self.out.lf() + POLLING_FUNCTIONS_STRING self.out.lf() + GETSET_WORKING_DIRECTORY - + if self.must_generate_mpi: self.out.lf() + RECV_HEADER_SLEEP_STRING - + self.output_handle_call() - + self.out.lf() + FOOTER_CODE_STRING - + self._result = self.out.string - + def output_local_includes(self): - if hasattr(self.specification_class, 'include_headers'): + if hasattr(self.specification_class, "include_headers"): for x in self.specification_class.include_headers: self.out.n() + '#include "' + x + '"' self.out.lf() - def 
output_needs_mpi(self): if self.needs_mpi and self.must_generate_mpi: - self.out.lf() + 'static bool NEEDS_MPI = true;' + self.out.lf() + "static bool NEEDS_MPI = true;" else: - self.out.lf() + 'static bool NEEDS_MPI = false;' + self.out.lf() + "static bool NEEDS_MPI = false;" self.out.lf().lf() - + def output_code_constants(self): for dtype in self.dtype_to_spec.keys(): dtype_spec = self.dtype_to_spec[dtype] - - maxin = self.mapping_from_dtype_to_maximum_number_of_inputvariables.get(dtype, 0) - self.out + 'static int MAX_' + dtype_spec.input_var_name.upper() + ' = ' + maxin + ";" + + maxin = self.mapping_from_dtype_to_maximum_number_of_inputvariables.get( + dtype, 0 + ) + ( + self.out + + "static int MAX_" + + dtype_spec.input_var_name.upper() + + " = " + + maxin + + ";" + ) self.out.lf() - - maxout = self.mapping_from_dtype_to_maximum_number_of_outputvariables.get(dtype, 0) - self.out + 'static int MAX_' + dtype_spec.output_var_name.upper() + ' = ' + maxout + ";" + + maxout = self.mapping_from_dtype_to_maximum_number_of_outputvariables.get( + dtype, 0 + ) + ( + self.out + + "static int MAX_" + + dtype_spec.output_var_name.upper() + + " = " + + maxout + + ";" + ) self.out.lf() - + def output_handle_call(self): - self.out.lf().lf() + 'bool handle_call() {' + self.out.lf().lf() + "bool handle_call() {" self.out.indent() - - self.out.lf() + 'int call_count = header_in[HEADER_CALL_COUNT];' - - self.out.lf().lf() + 'switch(header_in[HEADER_FUNCTION_ID]) {' + + self.out.lf() + "int call_count = header_in[HEADER_CALL_COUNT];" + + self.out.lf().lf() + "switch(header_in[HEADER_FUNCTION_ID]) {" self.out.indent() - self.out.lf() + 'case 0:' - self.out.indent().lf() + 'return false;' - self.out.lf() + 'break;' + self.out.lf() + "case 0:" + self.out.indent().lf() + "return false;" + self.out.lf() + "break;" self.out.dedent() - + self.output_sourcecode_for_functions() - - self.out.lf() + 'default:' + + self.out.lf() + "default:" self.out.indent() - self.out.lf() + 
'header_out[HEADER_FLAGS] = header_out[HEADER_FLAGS] | ERROR_FLAG;' - self.out.lf() + 'strings_out[0] = new char[100];' - self.out.lf() + 'sprintf(strings_out[0], "unknown function id: %d\\n", header_in[HEADER_FUNCTION_ID]);' - self.out.lf() + 'fprintf(stderr, "unknown function id: %d\\n", header_in[HEADER_FUNCTION_ID]);' - self.out.lf() + 'header_out[HEADER_STRING_COUNT] = 1;' + ( + self.out.lf() + + "header_out[HEADER_FLAGS] = header_out[HEADER_FLAGS] | ERROR_FLAG;" + ) + self.out.lf() + "strings_out[0] = new char[100];" + ( + self.out.lf() + + 'sprintf(strings_out[0], "unknown function id: %d\\n", header_in[HEADER_FUNCTION_ID]);' + ) + ( + self.out.lf() + + 'fprintf(stderr, "unknown function id: %d\\n", header_in[HEADER_FUNCTION_ID]);' + ) + self.out.lf() + "header_out[HEADER_STRING_COUNT] = 1;" self.out.dedent() - - self.out.dedent().lf() + '}' + + self.out.dedent().lf() + "}" self.out.dedent() - self.out.indent().lf() + 'return true;' - self.out.dedent().lf() + '}' + self.out.indent().lf() + "return true;" + self.out.dedent().lf() + "}" -class GenerateACHeaderStringFromASpecificationClass\ - (GenerateASourcecodeStringFromASpecificationClass): +class GenerateACHeaderStringFromASpecificationClass( + GenerateASourcecodeStringFromASpecificationClass +): @late def ignore_functions_from_specification_classes(self): return [] - + @late def underscore_functions_from_specification_classes(self): return [] - + @late def dtype_to_spec(self): return dtype_to_spec - + @late def make_extern_c(self): return True - + def must_include_interface_function_in_output(self, x): - if hasattr(x.specification,"internal_provided"): - return False - + if hasattr(x.specification, "internal_provided"): + return False + for cls in self.ignore_functions_from_specification_classes: if hasattr(cls, x.specification.name): return False - + return True - + def output_sourcecode_for_function(self): return GenerateACHeaderDefinitionStringFromAFunctionSpecification() - + def start(self): - self.out 
+ '#include "stdbool.h"' + self.out + '#include "stdbool.h"' self.out.lf() if self.make_extern_c: self.out + "#ifdef __cplusplus" self.out.lf() + 'extern "C" {' self.out.lf() + "#endif" self.out.lf() - + self.output_sourcecode_for_functions() - + if self.make_extern_c: self.out + "#ifdef __cplusplus" - self.out.lf() + '}' + self.out.lf() + "}" self.out.lf() + "#endif" self.out.lf() - + self.out.lf() - - self._result = self.out.string - + self._result = self.out.string -class GenerateACStubStringFromASpecificationClass\ - (GenerateASourcecodeStringFromASpecificationClass): +class GenerateACStubStringFromASpecificationClass( + GenerateASourcecodeStringFromASpecificationClass +): @late def dtype_to_spec(self): return dtype_to_spec - + @late def make_extern_c(self): return False - + def output_sourcecode_for_function(self): return create_definition.CreateCStub() def must_include_interface_function_in_output(self, x): - return not hasattr(x.specification,"internal_provided") - - def start(self): - + return not hasattr(x.specification, "internal_provided") + + def start(self): self.output_local_includes() - + self.out.lf() - + if self.make_extern_c: self.out + 'extern "C" {' self.out.indent().lf() - + self.output_sourcecode_for_functions() - + if self.make_extern_c: - self.out.dedent().lf() + '}' - + self.out.dedent().lf() + "}" + self.out.lf() - + self._result = self.out.string - - + def output_local_includes(self): self.out.n() - if hasattr(self.specification_class, 'include_headers'): + if hasattr(self.specification_class, "include_headers"): for x in self.specification_class.include_headers: self.out.n() + '#include "' + x + '"' - - - - diff --git a/src/amuse/rfi/tools/create_code.py b/src/amuse/rfi/tools/create_code.py index 623a66ccc3..61a3e27d2c 100644 --- a/src/amuse/rfi/tools/create_code.py +++ b/src/amuse/rfi/tools/create_code.py @@ -5,59 +5,59 @@ from amuse.support.core import late, print_out from amuse.rfi.core import legacy_function + class 
DTypeSpec(object): - def __init__(self, input_var_name, output_var_name, counter_name, - type, mpi_type = 'UNKNOWN'): + def __init__( + self, input_var_name, output_var_name, counter_name, type, mpi_type="UNKNOWN" + ): self.input_var_name = input_var_name self.output_var_name = output_var_name self.counter_name = counter_name self.type = type self.mpi_type = mpi_type - - -dtypes = ['int32', 'int64', 'float32', 'float64', 'bool', 'string'] +dtypes = ["int32", "int64", "float32", "float64", "bool", "string"] + class GenerateASourcecodeString(object): _result = None def __init__(self): pass - - @late + + @late def result(self): if self._result is None: self.start() return self._result - + @late def out(self): return print_out() - + @late def must_generate_mpi(self): - if 'CFLAGS' in os.environ: - return not (os.environ['CFLAGS'].find('-DNOMPI') >= 0) + if "CFLAGS" in os.environ: + return not (os.environ["CFLAGS"].find("-DNOMPI") >= 0) else: return True class GenerateASourcecodeStringFromASpecificationClass(GenerateASourcecodeString): - @late def interface_functions(self): attribute_names = dir(self.specification_class) interface_functions = [] for x in attribute_names: - if x.startswith('__'): + if x.startswith("__"): continue value = getattr(self.specification_class, x) if isinstance(value, legacy_function): interface_functions.append(value) - interface_functions.sort(key= lambda x: x.specification.nspec) + interface_functions.sort(key=lambda x: x.specification.nspec) return interface_functions - + @late def mapping_from_dtype_to_maximum_number_of_inputvariables(self): result = None @@ -66,17 +66,16 @@ def mapping_from_dtype_to_maximum_number_of_inputvariables(self): for parameter in x.specification.input_parameters: count = local.get(parameter.datatype, 0) local[parameter.datatype] = count + 1 - - + if result is None: result = local else: for key, count in local.items(): previous_count = result.get(key, 0) result[key] = max(count, previous_count) - + return result - 
+ @late def mapping_from_dtype_to_maximum_number_of_outputvariables(self): result = None @@ -85,54 +84,52 @@ def mapping_from_dtype_to_maximum_number_of_outputvariables(self): for parameter in x.specification.output_parameters: count = local.get(parameter.datatype, 0) local[parameter.datatype] = count + 1 - + if not x.specification.result_type is None: count = local.get(x.specification.result_type, 0) local[x.specification.result_type] = count + 1 - + if result is None: result = local else: for key, count in local.items(): previous_count = result.get(key, 0) result[key] = max(count, previous_count) - + return result - + def must_include_interface_function_in_output(self, x): return True - + def output_sourcecode_for_functions(self): for x in self.interface_functions: if x.specification.id == 0: continue if not self.must_include_interface_function_in_output(x): continue - + self.out.lf() uc = self.output_sourcecode_for_function() uc.specification = x.specification uc.out = self.out uc.start() self.out.lf() - + + class DTypeToSpecDictionary(object): - def __init__(self, dict): self.mapping = {} for datatype, value in dict.items(): self.mapping[datatype] = value - + def __getitem__(self, datatype): return self.mapping[datatype] - + def __len__(self): return len(self.mapping) - + def values(self): - return list(self.mapping.values()) # python3: maybe remove list - + return list(self.mapping.values()) # python3: maybe remove list + def keys(self): - return list(self.mapping.keys()) # python3: maybe remove list - - + return list(self.mapping.keys()) # python3: maybe remove list diff --git a/src/amuse/rfi/tools/create_definition.py b/src/amuse/rfi/tools/create_definition.py index d9417f7877..cb73fdb99f 100644 --- a/src/amuse/rfi/tools/create_definition.py +++ b/src/amuse/rfi/tools/create_definition.py @@ -2,33 +2,35 @@ from amuse.support.core import late, print_out + def strip_indent(string_with_indents): - return re.sub('^ *\n', '', string_with_indents.rstrip()) + return 
re.sub("^ *\n", "", string_with_indents.rstrip()) -class CodeDocStringProperty(object): +class CodeDocStringProperty(object): """ Return a docstring generated from a function specification """ + def __get__(self, instance, owner): if instance is None: if hasattr(owner, "__init__"): return owner.__init__.__doc__ else: return self - + usecase = CreateDescriptionOfAFunctionSpecification() usecase.specification = instance.specification usecase.start() return usecase.out.string -class CreateDescriptionOfAFunctionSpecification(object): +class CreateDescriptionOfAFunctionSpecification(object): @late def out(self): return print_out() - + def start(self): self.output_function_description() self.out.lf() @@ -39,14 +41,12 @@ def start(self): self.output_parameter_returntype() self.out.lf() - def output_function_description(self): self.output_multiline_string(self.specification.description) self.out.lf() - def output_multiline_string(self, string): - lines = string.split('\n') + lines = string.split("\n") first = True for line in lines: if first: @@ -54,129 +54,127 @@ def output_multiline_string(self, string): else: self.out.lf() self.out + line - + def output_cfunction_definition(self): - self.out + '.. code-block:: c' + self.out + ".. code-block:: c" self.out.indent() self.out.lf().lf() - + x = CreateCStub() x.convert_datatypes = False x.out = self.out x.specification = self.specification x.start() - + self.out.dedent() self.out.lf().lf() - + def output_fortran_function_definition(self): - - self.out + '.. code-block:: fortran' + self.out + ".. 
code-block:: fortran" self.out.indent() self.out.lf().lf() - + x = CreateFortranStub() x.out = self.out x.specification = self.specification x.start() - - + self.out.dedent() self.out.lf().lf() - + def output_parameter_descriptions(self): for parameter in self.specification.parameters: self.out.lf() - self.out + ':param ' + parameter.name + ': ' + self.out + ":param " + parameter.name + ": " self.out.indent() self.output_multiline_string(strip_indent(parameter.description)) self.out.dedent() self.out.lf() - self.out + ':type ' + parameter.name + ': ' - self.out + parameter.datatype + ', ' - self.output_parameter_direction(parameter) - + self.out + ":type " + parameter.name + ": " + self.out + parameter.datatype + ", " + self.output_parameter_direction(parameter) def output_parameter_direction(self, parameter): - #self.out + '(' + # self.out + '(' if parameter.direction == self.specification.IN: - self.out + 'IN' + self.out + "IN" if parameter.direction == self.specification.INOUT: - self.out + 'INOUT' + self.out + "INOUT" if parameter.direction == self.specification.OUT: - self.out + 'OUT' - #self.out + ')' + self.out + "OUT" + # self.out + ')' def output_parameter_returntype(self): if self.specification.result_type is None: return self.out.lf() - self.out + ':returns: ' + self.out + ":returns: " self.out.indent() self.output_multiline_string(strip_indent(self.specification.result_doc)) self.out.dedent() - + + class CreateInterfaceDefinitionDocument(object): @late def out(self): return print_out() - + def start(self): pass - + + class CreateFortranStub(object): @late def out(self): return print_out() - + def start(self): self.output_subprogram_start() self.output_parameter_type_definiton_lines() - - + if not self.output_definition_only: self.output_subprogram_content() self.output_subprogram_end() - - @late + + @late def specification_is_for_function(self): return not self.specification.result_type is None - + @late def output_definition_only(self): return True - - 
@late + + @late def subprogram_string(self): if self.specification_is_for_function: - return 'function' + return "function" else: - return 'subroutine' - + return "subroutine" + @late def dtype_to_parameters(self): result = {} - for parameter in self.specification.parameters: - parameters = result.get(parameter.datatype,[]) + for parameter in self.specification.parameters: + parameters = result.get(parameter.datatype, []) parameters.append(parameter) result[parameter.datatype] = parameters return result - + @late def dtype_to_fortrantype(self): return { - 'int32':'integer', - 'float64':'double precision', - 'float32':'real', - 'string':'character(len=*)', - 'bool':'logical', + "int32": "integer", + "float64": "double precision", + "float32": "real", + "string": "character(len=*)", + "bool": "logical", } - + def output_subprogram_start(self): - self.out + self.subprogram_string + ' ' + self.out + self.subprogram_string + " " self.out + self.specification.name - self.out + '(' + self.out + "(" self.out.indent() self.out.indent() first = True @@ -184,91 +182,99 @@ def output_subprogram_start(self): if first: first = False else: - self.out + ', ' - + self.out + ", " + length_of_the_argument_statement = len(parameter.name) - new_length_of_the_line = self.out.number_of_characters_on_current_line + length_of_the_argument_statement + new_length_of_the_line = ( + self.out.number_of_characters_on_current_line + + length_of_the_argument_statement + ) if new_length_of_the_line > 74: - self.out + ' &' + self.out + " &" self.out.lf() self.out + parameter.name - - self.out + ')' + + self.out + ")" self.out.dedent() - + if not self.output_definition_only: - self.out.lf() + 'implicit none' - + self.out.lf() + "implicit none" def output_parameter_type_definiton_lines(self): - for dtype,parameters in self.dtype_to_parameters.items(): + for dtype, parameters in self.dtype_to_parameters.items(): typestring = self.dtype_to_fortrantype[dtype] first = True - + self.out.lf() - self.out + 
typestring + ' :: ' - + self.out + typestring + " :: " + for parameter in parameters: - - length_of_the_argument_statement = len(parameter.name) - new_length_of_the_line = self.out.number_of_characters_on_current_line + length_of_the_argument_statement + new_length_of_the_line = ( + self.out.number_of_characters_on_current_line + + length_of_the_argument_statement + ) if new_length_of_the_line > 74: first = True self.out.lf() - self.out + typestring + ' :: ' - + self.out + typestring + " :: " + if first: first = False else: - self.out + ', ' - - self.out + parameter.name - + self.out + ", " + + self.out + parameter.name + self.output_function_type() - + def output_function_type(self): if self.specification_is_for_function: typestring = self.dtype_to_fortrantype[self.specification.result_type] self.out.lf() - self.out + typestring + ' :: ' + self.specification.name + self.out + typestring + " :: " + self.specification.name def output_subprogram_end(self): self.out.dedent() self.out.lf() - self.out + 'end ' + self.subprogram_string + self.out + "end " + self.subprogram_string def output_subprogram_content(self): if not self.specification.result_type is None: self.out.lf() - self.out + self.specification.name + '=' + self.dtype_to_returnvalue[self.specification.result_type] - + ( + self.out + + self.specification.name + + "=" + + self.dtype_to_returnvalue[self.specification.result_type] + ) + @late def dtype_to_returnvalue(self): return { - 'int32':'0', - 'float64':'0.0', - 'float32':'0.0', - 'string':'0', - 'bool':'0', + "int32": "0", + "float64": "0.0", + "float32": "0.0", + "string": "0", + "bool": "0", } - + class CreateCStub(object): @late def out(self): return print_out() - + def start(self): if self.specification.result_type is None: - self.out + 'void ' + self.out + "void " else: typestring = self.dtype_to_ctype[self.specification.result_type] self.out + typestring - self.out + ' ' - + self.out + " " + self.out + self.specification.name - self.out + '(' + 
self.out + "(" self.out.indent() first = True for parameter in self.specification.parameters: @@ -276,22 +282,25 @@ def start(self): if first: first = False else: - self.out + ', ' - + self.out + ", " + length_of_the_argument_statement = len(typestring) + len(parameter.name) + 3 - new_length_of_the_line = self.out.number_of_characters_on_current_line + length_of_the_argument_statement + new_length_of_the_line = ( + self.out.number_of_characters_on_current_line + + length_of_the_argument_statement + ) if new_length_of_the_line > 74: self.out.lf() self.out + typestring if parameter.is_output(): - self.out + ' *' - self.out + ' ' + parameter.name - + self.out + " *" + self.out + " " + parameter.name + self.out.dedent() - self.out + ')' - + self.out + ")" + if self.output_definition_only: - self.out + ';' + self.out + ";" else: self.output_function_content() @@ -301,61 +310,63 @@ def start(self): def result(self): self.start() return self._result - + def output_function_content(self): - self.out + '{' + self.out + "{" self.out.indent() if not self.specification.result_type is None: self.out.lf() - self.out + 'return ' + self.dtype_to_returnvalue[self.specification.result_type] + ';' + ( + self.out + + "return " + + self.dtype_to_returnvalue[self.specification.result_type] + + ";" + ) self.out.dedent().lf() - self.out + '}' - + self.out + "}" + @late def dtype_to_parameters(self): result = {} - for parameter in self.specification.parameters: - parameters = result.get(parameter.datatype,[]) + for parameter in self.specification.parameters: + parameters = result.get(parameter.datatype, []) parameters.append(parameter) result[parameter.datatype] = parameters return result - + @late def output_definition_only(self): - return (not self.convert_datatypes) - + return not self.convert_datatypes + @late def convert_datatypes(self): return True - + @late def dtype_to_ctype(self): if self.convert_datatypes: return { - 'int32':'int', - 'float64':'double', - 'float32':'float', - 
'string':'char *', - 'bool':'_Bool', + "int32": "int", + "float64": "double", + "float32": "float", + "string": "char *", + "bool": "_Bool", } else: return { - 'int32':'int32', - 'float64':'float64', - 'float32':'float32', - 'string':'char *', - 'bool':'_Bool', + "int32": "int32", + "float64": "float64", + "float32": "float32", + "string": "char *", + "bool": "_Bool", } + @late def dtype_to_returnvalue(self): return { - 'int32':'0', - 'float64':'0.0', - 'float32':'0.0', - 'string':'0', - 'bool':'0', + "int32": "0", + "float64": "0.0", + "float32": "0.0", + "string": "0", + "bool": "0", } - - - - diff --git a/src/amuse/rfi/tools/create_dir.py b/src/amuse/rfi/tools/create_dir.py index 5b4c4cf154..19395f84f5 100644 --- a/src/amuse/rfi/tools/create_dir.py +++ b/src/amuse/rfi/tools/create_dir.py @@ -223,166 +223,173 @@ def test1(self): """ + class CreateADirectoryAndPopulateItWithFiles(OptionalAttributes): - @late def path_of_the_root_directory(self): return os.path.dirname(os.path.dirname(__file__)) - + @late def name_of_the_community_code(self): return self.name_of_the_code_interface_class.lower() - + @late def name_of_the_python_module(self): - return 'interface.py' - + return "interface.py" + @late def name_of_the_test_module(self): - return 'test_{0}.py'.format(self.name_of_the_community_code) - + return "test_{0}.py".format(self.name_of_the_community_code) + @late def name_of_the_interface_code(self): - return 'interface' - + return "interface" + @late def name_of_the_code_interface_class(self): - return 'MyCode' - + return "MyCode" + @late def name_of_the_community_interface_class(self): - return self.name_of_the_code_interface_class + 'Interface' - + return self.name_of_the_code_interface_class + "Interface" + @late def name_of_the_code_directory(self): - return 'src' - + return "src" + @late def name_for_import_of_the_interface_module(self): - return '.' + self.name_of_the_python_module[:-3] - + return "." 
+ self.name_of_the_python_module[:-3] + @late def path_of_the_community_code(self): - return os.path.join(self.path_of_the_root_directory, self.name_of_the_community_code) - + return os.path.join( + self.path_of_the_root_directory, self.name_of_the_community_code + ) + @late def path_of_the_source_code(self): - return os.path.join(self.path_of_the_community_code, self.name_of_the_code_directory) - + return os.path.join( + self.path_of_the_community_code, self.name_of_the_code_directory + ) + @late def path_of_the_init_file(self): - return os.path.join(self.path_of_the_community_code, '__init__.py') - + return os.path.join(self.path_of_the_community_code, "__init__.py") + @late def path_of_the_interface_file(self): - return os.path.join(self.path_of_the_community_code, self.name_of_the_python_module) - + return os.path.join( + self.path_of_the_community_code, self.name_of_the_python_module + ) + @late def path_of_the_test_file(self): - return os.path.join(self.path_of_the_community_code, self.name_of_the_test_module) - + return os.path.join( + self.path_of_the_community_code, self.name_of_the_test_module + ) + @late def path_of_the_makefile(self): - return os.path.join(self.path_of_the_community_code, 'Makefile') - + return os.path.join(self.path_of_the_community_code, "Makefile") + @late def path_of_the_code_makefile(self): - return os.path.join(self.path_of_the_source_code, 'Makefile') - + return os.path.join(self.path_of_the_source_code, "Makefile") + @late def path_of_the_code_examplefile(self): raise NotImplementedError() - + @late def path_of_the_interface_examplefile(self): raise NotImplementedError() - + @late def path_of_amuse(self): return self.amuse_root_dir - + @late def reference_to_amuse_path(self): return os.path.relpath(self.path_of_amuse, self.path_of_the_community_code) - - @late + + @late def name_of_the_superclass_for_the_community_code_interface_class(self): return "CodeInterface" - + @late def 
name_of_the_superclass_for_the_code_interface_class(self): return "InCodeComponentImplementation" - + @late def amuse_root_dir(self): return get_amuse_root_dir() - + @late def include_headers_or_modules(self): return "include_headers = ['worker_code.h']" - + def start(self): - self.make_directories() self.make_python_files() self.make_makefile() self.make_example_files() - - + def make_directories(self): os.mkdir(self.path_of_the_community_code) os.mkdir(self.path_of_the_source_code) - + def make_python_files(self): with open(self.path_of_the_init_file, "w") as f: f.write("# generated file") - + with open(self.path_of_the_interface_file, "w") as f: string = interface_file_template.format(self) f.write(string) - + with open(self.path_of_the_test_file, "w") as f: string = test_file_template.format(self) f.write(string) - + def make_makefile(self): pass - + def make_example_files(self): pass - - -class CreateADirectoryAndPopulateItWithFilesForACCode(CreateADirectoryAndPopulateItWithFiles): - + + +class CreateADirectoryAndPopulateItWithFilesForACCode( + CreateADirectoryAndPopulateItWithFiles +): @late def path_of_the_code_examplefile(self): - return os.path.join(self.path_of_the_source_code, 'test.cc') - + return os.path.join(self.path_of_the_source_code, "test.cc") + @late def path_of_the_interface_examplefile(self): - return os.path.join(self.path_of_the_community_code, self.name_of_the_interface_code + '.cc') - + return os.path.join( + self.path_of_the_community_code, self.name_of_the_interface_code + ".cc" + ) + def make_makefile(self): - with open(self.path_of_the_makefile, "w") as f: string = makefile_template_cxx.format(self) f.write(string) - + def make_example_files(self): with open(self.path_of_the_code_makefile, "w") as f: string = code_makefile_template_cxx.format(self) f.write(string) - + with open(self.path_of_the_code_examplefile, "w") as f: string = code_examplefile_template_cxx f.write(string) - + with open(self.path_of_the_interface_examplefile, "w") 
as f: string = interface_examplefile_template_cxx f.write(string) - makefile_template_fortran = """\ # standard amuse configuration include # config.mk will be made after ./configure has run @@ -482,15 +489,20 @@ def make_example_files(self): end module """ -class CreateADirectoryAndPopulateItWithFilesForAFortranCode(CreateADirectoryAndPopulateItWithFiles): - + + +class CreateADirectoryAndPopulateItWithFilesForAFortranCode( + CreateADirectoryAndPopulateItWithFiles +): @late def path_of_the_code_examplefile(self): - return os.path.join(self.path_of_the_source_code, 'test.f90') - + return os.path.join(self.path_of_the_source_code, "test.f90") + @late def path_of_the_interface_examplefile(self): - return os.path.join(self.path_of_the_community_code, self.name_of_the_interface_code + '.f90') + return os.path.join( + self.path_of_the_community_code, self.name_of_the_interface_code + ".f90" + ) @late def include_headers_or_modules(self): @@ -498,24 +510,22 @@ def include_headers_or_modules(self): @late def name_of_the_interface_module(self): - return '{0}Interface'.format(self.name_of_the_community_code) + return "{0}Interface".format(self.name_of_the_community_code) def make_makefile(self): - with open(self.path_of_the_makefile, "w") as f: string = makefile_template_fortran.format(self) f.write(string) - + def make_example_files(self): with open(self.path_of_the_code_makefile, "w") as f: string = code_makefile_template_fortran.format(self) f.write(string) - + with open(self.path_of_the_code_examplefile, "w") as f: string = code_examplefile_template_fortran f.write(string) - + with open(self.path_of_the_interface_examplefile, "w") as f: string = interface_examplefile_template_fortran.format(self) f.write(string) - diff --git a/src/amuse/rfi/tools/create_fortran.py b/src/amuse/rfi/tools/create_fortran.py index edb4367511..3cb8d2d6e1 100644 --- a/src/amuse/rfi/tools/create_fortran.py +++ b/src/amuse/rfi/tools/create_fortran.py @@ -12,15 +12,28 @@ from amuse.rfi.core import 
LegacyFunctionSpecification - -dtype_to_spec = DTypeToSpecDictionary({ - 'int32' : DTypeSpec('integers_in','integers_out','HEADER_INTEGER_COUNT', 'integer', 'integer'), - 'int64' : DTypeSpec('longs_in', 'longs_out', 'HEADER_LONG_COUNT', 'integer*8', 'long'), - 'float32' : DTypeSpec('floats_in', 'floats_out', 'HEADER_FLOAT_COUNT', 'real*4', 'float'), - 'float64' : DTypeSpec('doubles_in', 'doubles_out', 'HEADER_DOUBLE_COUNT', 'real*8', 'double'), - 'bool' : DTypeSpec('booleans_in', 'booleans_out', 'HEADER_BOOLEAN_COUNT', 'logical', 'boolean'), - 'string' : DTypeSpec('strings_in', 'strings_out', 'HEADER_STRING_COUNT', 'integer*4', 'integer'), -}) +dtype_to_spec = DTypeToSpecDictionary( + { + "int32": DTypeSpec( + "integers_in", "integers_out", "HEADER_INTEGER_COUNT", "integer", "integer" + ), + "int64": DTypeSpec( + "longs_in", "longs_out", "HEADER_LONG_COUNT", "integer*8", "long" + ), + "float32": DTypeSpec( + "floats_in", "floats_out", "HEADER_FLOAT_COUNT", "real*4", "float" + ), + "float64": DTypeSpec( + "doubles_in", "doubles_out", "HEADER_DOUBLE_COUNT", "real*8", "double" + ), + "bool": DTypeSpec( + "booleans_in", "booleans_out", "HEADER_BOOLEAN_COUNT", "logical", "boolean" + ), + "string": DTypeSpec( + "strings_in", "strings_out", "HEADER_STRING_COUNT", "integer*4", "integer" + ), + } +) CONSTANTS_STRING = """ integer HEADER_FLAGS, HEADER_CALL_ID, HEADER_FUNCTION_ID, HEADER_CALL_COUNT, & @@ -1129,7 +1142,7 @@ end if """ -GETSET_WORKING_DIRECTORY=""" +GETSET_WORKING_DIRECTORY = """ function set_working_directory(directory) result(ret) {0} @@ -1148,313 +1161,405 @@ """ - class GenerateAFortranStringOfAFunctionSpecification(GenerateASourcecodeString): MAX_STRING_LEN = 256 - + @late def specification(self): - raise exceptions.AmuseException("No specification set, please set the specification first") - + raise exceptions.AmuseException( + "No specification set, please set the specification first" + ) + @late def underscore_functions_from_specification_classes(self): 
return [] - + @late def dtype_to_spec(self): return dtype_to_spec - - def index_string(self, index, must_copy_in_to_out = False): + + def index_string(self, index, must_copy_in_to_out=False): if self.specification.must_handle_array and not must_copy_in_to_out: if index == 0: - return '1' + return "1" else: - return '( %d * call_count) + 1' % (index ) - elif self.specification.can_handle_array or (self.specification.must_handle_array and must_copy_in_to_out): + return "( %d * call_count) + 1" % (index) + elif self.specification.can_handle_array or ( + self.specification.must_handle_array and must_copy_in_to_out + ): if index == 0: - return 'i' + return "i" else: if index == -1: return "i - 1" else: - return '( %d * call_count) + i' % index + return "( %d * call_count) + i" % index else: return index + 1 - - def start(self): + + def start(self): self.specification.prepare_output_parameters() - + self.output_casestmt_start() self.out.indent() - - #self.output_lines_before_with_clear_out_variables() - #self.output_lines_before_with_clear_input_variables() - + + # self.output_lines_before_with_clear_out_variables() + # self.output_lines_before_with_clear_input_variables() + if self.specification.must_handle_array: pass elif self.specification.can_handle_array: - self.out.lf() + 'do i = 1, call_count, 1' + self.out.lf() + "do i = 1, call_count, 1" self.out.indent() - - #self.output_lines_before_with_inout_variables() + + # self.output_lines_before_with_inout_variables() self.output_function_start() self.output_function_parameters() self.output_function_end() self.output_lines_with_inout_variables() - - + if self.specification.must_handle_array: if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] - self.out.lf() + 'DO i = 2, call_count' + self.out.lf() + "DO i = 2, call_count" self.out.indent() - self.out.lf() + spec.output_var_name + '(i)' + ' = ' + spec.output_var_name + '(1)' + ( + self.out.lf() + + 
spec.output_var_name + + "(i)" + + " = " + + spec.output_var_name + + "(1)" + ) self.out.dedent() - self.out.lf() + 'END DO' + self.out.lf() + "END DO" elif self.specification.can_handle_array: self.out.dedent() - self.out.lf() + 'end do' - + self.out.lf() + "end do" + self.output_lines_with_number_of_outputs() self.output_casestmt_end() self.out.dedent() self._result = self.out.string - + def output_function_parameters(self): self.out.indent() - + first = True - + for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if first: first = False - self.out + ' &' + self.out + " &" else: - self.out + ' ,&' - + self.out + " ,&" + if parameter.direction == LegacyFunctionSpecification.IN: -# if parameter.datatype == 'string': -# self.out.n() + 'input_characters(' -# self.out + '( (' + self.index_string(parameter.input_index) + ')* ' + self.MAX_STRING_LEN + ')' -# self.out + ':' + '(((' + self.index_string(parameter.input_index) + ')* ' + self.MAX_STRING_LEN + ') +' -# self.out + '(' + spec.input_var_name + '(' + self.index_string(parameter.input_index) + ')' + '-' -# self.out + 'get_offset(' + self.index_string(parameter.input_index) + ' - 1 , '+spec.input_var_name +') ))' -# self.out + ')' -# else: - if parameter.datatype == 'string': - self.out.n() + 'strings_in(' + self.index_string(parameter.input_index) + ')' + # if parameter.datatype == 'string': + # self.out.n() + 'input_characters(' + # self.out + '( (' + self.index_string(parameter.input_index) + ')* ' + self.MAX_STRING_LEN + ')' + # self.out + ':' + '(((' + self.index_string(parameter.input_index) + ')* ' + self.MAX_STRING_LEN + ') +' + # self.out + '(' + spec.input_var_name + '(' + self.index_string(parameter.input_index) + ')' + '-' + # self.out + 'get_offset(' + self.index_string(parameter.input_index) + ' - 1 , '+spec.input_var_name +') ))' + # self.out + ')' + # else: + if parameter.datatype == "string": + ( + self.out.n() + + "strings_in(" + + 
self.index_string(parameter.input_index) + + ")" + ) else: - self.out.n() + spec.input_var_name - self.out + '(' + self.index_string(parameter.input_index) + ')' + self.out.n() + spec.input_var_name + self.out + "(" + self.index_string(parameter.input_index) + ")" if parameter.direction == LegacyFunctionSpecification.INOUT: -# if parameter.datatype == 'string': -# self.out.n() + 'output_characters(' -# self.out + '((' + self.index_string(parameter.output_index) + ')* ' + self.MAX_STRING_LEN + ')' -# self.out + ':' + '(((' + self.index_string(parameter.output_index) + ')+1) * ' + self.MAX_STRING_LEN + ' - 1)' -# self.out + ')' -# else: -# if parameter.datatype == 'string': -# self.out.n() + spec.input_var_name -# self.out + '(' + self.index_string(parameter.input_index) + ', :)' -# else: - self.out.n() + spec.input_var_name - self.out + '(' + self.index_string(parameter.input_index) + ')' + # if parameter.datatype == 'string': + # self.out.n() + 'output_characters(' + # self.out + '((' + self.index_string(parameter.output_index) + ')* ' + self.MAX_STRING_LEN + ')' + # self.out + ':' + '(((' + self.index_string(parameter.output_index) + ')+1) * ' + self.MAX_STRING_LEN + ' - 1)' + # self.out + ')' + # else: + # if parameter.datatype == 'string': + # self.out.n() + spec.input_var_name + # self.out + '(' + self.index_string(parameter.input_index) + ', :)' + # else: + self.out.n() + spec.input_var_name + self.out + "(" + self.index_string(parameter.input_index) + ")" elif parameter.direction == LegacyFunctionSpecification.OUT: -# if parameter.datatype == 'string': -# self.out.n() + 'output_characters(' -# self.out + '((' + self.index_string(parameter.output_index) + ')* ' + self.MAX_STRING_LEN + ')' -# self.out + ':' + '(((' + self.index_string(parameter.output_index) + ')+1) * ' + self.MAX_STRING_LEN + ' - 1)' -# self.out + ')' -# else: -# if parameter.datatype == 'string': -# self.out.n() + spec.output_var_name -# self.out + '(' + 
self.index_string(parameter.output_index) + ')(1:50)' -# else: - self.out.n() + spec.output_var_name - self.out + '(' + self.index_string(parameter.output_index) + ')' + # if parameter.datatype == 'string': + # self.out.n() + 'output_characters(' + # self.out + '((' + self.index_string(parameter.output_index) + ')* ' + self.MAX_STRING_LEN + ')' + # self.out + ':' + '(((' + self.index_string(parameter.output_index) + ')+1) * ' + self.MAX_STRING_LEN + ' - 1)' + # self.out + ')' + # else: + # if parameter.datatype == 'string': + # self.out.n() + spec.output_var_name + # self.out + '(' + self.index_string(parameter.output_index) + ')(1:50)' + # else: + self.out.n() + spec.output_var_name + self.out + "(" + self.index_string(parameter.output_index) + ")" elif parameter.direction == LegacyFunctionSpecification.LENGTH: - self.out.n() + 'call_count' - + self.out.n() + "call_count" + self.out.dedent() - + def output_lines_with_inout_variables(self): - for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if parameter.direction == LegacyFunctionSpecification.INOUT: if self.specification.must_handle_array: - self.out.lf() + 'DO i = 1, call_count' - self.out.indent() - - self.out.n() + spec.output_var_name - self.out + '(' + self.index_string(parameter.output_index, must_copy_in_to_out = True) + ')' - self.out + ' = ' - self.out + spec.input_var_name + '(' + self.index_string(parameter.input_index, must_copy_in_to_out = True) + ')' - + self.out.lf() + "DO i = 1, call_count" + self.out.indent() + + self.out.n() + spec.output_var_name + ( + self.out + + "(" + + self.index_string( + parameter.output_index, must_copy_in_to_out=True + ) + + ")" + ) + self.out + " = " + ( + self.out + + spec.input_var_name + + "(" + + self.index_string(parameter.input_index, must_copy_in_to_out=True) + + ")" + ) + if self.specification.must_handle_array: - self.out.dedent() - self.out.lf() + 'END DO' - + self.out.dedent() + self.out.lf() + "END DO" + def 
output_lines_before_with_clear_out_variables(self): for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if parameter.is_output(): - if parameter.datatype == 'string': - self.out.lf() + 'output_characters = "x"' + if parameter.datatype == "string": + self.out.lf() + 'output_characters = "x"' return - + def output_lines_before_with_clear_input_variables(self): for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if parameter.is_input(): - if parameter.datatype == 'string': - self.out.lf() + 'input_characters = "x"' + if parameter.datatype == "string": + self.out.lf() + 'input_characters = "x"' return - - - + def output_lines_before_with_inout_variables(self): - for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - - + if parameter.direction == LegacyFunctionSpecification.IN: - if parameter.datatype == 'string': - self.out.n() + 'input_characters(' - self.out + '( (' + self.index_string(parameter.input_index) + ')* ' + self.MAX_STRING_LEN + ')' - self.out + ':' + '(((' + self.index_string(parameter.input_index) + ')+1) * ' + self.MAX_STRING_LEN + ' - 1)' - self.out + ') = &' + if parameter.datatype == "string": + self.out.n() + "input_characters(" + ( + self.out + + "( (" + + self.index_string(parameter.input_index) + + ")* " + + self.MAX_STRING_LEN + + ")" + ) + ( + self.out + + ":" + + "(((" + + self.index_string(parameter.input_index) + + ")+1) * " + + self.MAX_STRING_LEN + + " - 1)" + ) + self.out + ") = &" self.out.lf() - self.out + 'characters(' - self.out + 'get_offset(' + self.index_string(parameter.input_index) + ' - 1 , '+spec.input_var_name +')' - self.out + ':' + spec.input_var_name + '(' + self.index_string(parameter.input_index) + ')' - self.out + ')' - + self.out + "characters(" + ( + self.out + + "get_offset(" + + self.index_string(parameter.input_index) + + " - 1 , " + + spec.input_var_name + + ")" + ) + ( + self.out + + 
":" + + spec.input_var_name + + "(" + + self.index_string(parameter.input_index) + + ")" + ) + self.out + ")" + if parameter.direction == LegacyFunctionSpecification.INOUT: - if parameter.datatype == 'string': - self.out.n() + 'output_characters(' - self.out + '( (' + self.index_string(parameter.output_index) + ')* ' + self.MAX_STRING_LEN + ')' - self.out + ':' + '(((' + self.index_string(parameter.output_index) + ')+1) * ' + self.MAX_STRING_LEN + ' - 1)' - self.out + ') = &' + if parameter.datatype == "string": + self.out.n() + "output_characters(" + ( + self.out + + "( (" + + self.index_string(parameter.output_index) + + ")* " + + self.MAX_STRING_LEN + + ")" + ) + ( + self.out + + ":" + + "(((" + + self.index_string(parameter.output_index) + + ")+1) * " + + self.MAX_STRING_LEN + + " - 1)" + ) + self.out + ") = &" self.out.lf() - self.out + 'characters(' - self.out + 'get_offset(' + self.index_string(parameter.input_index) + ' - 1 , '+spec.input_var_name +')' - self.out + ':' + spec.input_var_name + '(' + self.index_string(parameter.input_index) + ')' - self.out + ')' - + self.out + "characters(" + ( + self.out + + "get_offset(" + + self.index_string(parameter.input_index) + + " - 1 , " + + spec.input_var_name + + ")" + ) + ( + self.out + + ":" + + spec.input_var_name + + "(" + + self.index_string(parameter.input_index) + + ")" + ) + self.out + ")" + def output_lines_with_number_of_outputs(self): dtype_to_count = {} - + for parameter in self.specification.output_parameters: count = dtype_to_count.get(parameter.datatype, 0) dtype_to_count[parameter.datatype] = count + 1 - + if not self.specification.result_type is None: count = dtype_to_count.get(self.specification.result_type, 0) dtype_to_count[self.specification.result_type] = count + 1 - - for dtype in dtype_to_count: + + for dtype in dtype_to_count: spec = self.dtype_to_spec[dtype] count = dtype_to_count[dtype] - self.out.n() + 'header_out(' + spec.counter_name + ') = ' + count + ' * call_count' + ( + 
self.out.n() + + "header_out(" + + spec.counter_name + + ") = " + + count + + " * call_count" + ) pass - + def output_function_end(self): - self.out + ' &' - self.out.n() + ')' - + self.out + " &" + self.out.n() + ")" + def output_function_start(self): - self.out.n() + self.out.n() if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] -# if self.specification.result_type == 'string': -# self.out + 'output_characters(' -# self.out + '( (' + self.index_string(0) + ')* ' + self.MAX_STRING_LEN + ')' -# self.out + ':' + '(((' + self.index_string(0) + ')+1)*' + self.MAX_STRING_LEN + '-1)' -# self.out + ') = &' -# self.out.lf() -# else: + # if self.specification.result_type == 'string': + # self.out + 'output_characters(' + # self.out + '( (' + self.index_string(0) + ')* ' + self.MAX_STRING_LEN + ')' + # self.out + ':' + '(((' + self.index_string(0) + ')+1)*' + self.MAX_STRING_LEN + '-1)' + # self.out + ') = &' + # self.out.lf() + # else: self.out + spec.output_var_name - self.out + '(' + self.index_string(0) + ')' + ' = ' - else: - self.out + 'CALL ' - self.out + self.specification.name + self.out + "(" + self.index_string(0) + ")" + " = " + else: + self.out + "CALL " + self.out + self.specification.name if self.must_add_underscore_to_function(self.specification): - self.out + '_' - self.out + '(' - + self.out + "_" + self.out + "(" + def output_casestmt_start(self): - self.out + 'CASE(' + self.specification.id + ')' - + self.out + "CASE(" + self.specification.id + ")" + def output_casestmt_end(self): - self.out.n() - + self.out.n() + def must_add_underscore_to_function(self, x): - for cls in self.underscore_functions_from_specification_classes: if hasattr(cls, x.name): return True - + return False - - -class GenerateAFortranSourcecodeStringFromASpecificationClass(GenerateASourcecodeStringFromASpecificationClass): + + +class GenerateAFortranSourcecodeStringFromASpecificationClass( + 
GenerateASourcecodeStringFromASpecificationClass +): MAX_STRING_LEN = 256 @late def dtype_to_spec(self): - return dtype_to_spec - + return dtype_to_spec + @late def number_of_types(self): return len(self.dtype_to_spec) - + @late def length_of_the_header(self): return 2 + self.number_of_types - + @late def underscore_functions_from_specification_classes(self): return [] - + def output_sourcecode_for_function(self): result = GenerateAFortranStringOfAFunctionSpecification() - result.underscore_functions_from_specification_classes = self.underscore_functions_from_specification_classes + result.underscore_functions_from_specification_classes = ( + self.underscore_functions_from_specification_classes + ) return result - + def output_needs_mpi(self): - self.out.lf() + 'logical NEEDS_MPI' - - if (hasattr(self, 'needs_mpi') and self.needs_mpi) and self.must_generate_mpi: - self.out.lf() + 'parameter (NEEDS_MPI=.true.)' + self.out.lf() + "logical NEEDS_MPI" + + if (hasattr(self, "needs_mpi") and self.needs_mpi) and self.must_generate_mpi: + self.out.lf() + "parameter (NEEDS_MPI=.true.)" else: - self.out.lf() + 'parameter (NEEDS_MPI=.false.)' - + self.out.lf() + "parameter (NEEDS_MPI=.false.)" + self.out.lf().lf() - + def start(self): self.use_iso_c_bindings = config.compilers.fc_iso_c_bindings - self.out + GETSET_WORKING_DIRECTORY.format("" if not config.compilers.ifort_version else " use ifport") + self.out + GETSET_WORKING_DIRECTORY.format( + "" if not config.compilers.ifort_version else " use ifport" + ) - self.out + 'program amuse_worker_program' + self.out + "program amuse_worker_program" self.out.indent() - + self.output_modules() - - if self.use_iso_c_bindings: - self.out.n() + 'use iso_c_binding' - - self.out.n() + 'implicit none' + + if self.use_iso_c_bindings: + self.out.n() + "use iso_c_binding" + + self.out.n() + "implicit none" self.out.n() + CONSTANTS_STRING - + self.output_needs_mpi() self.output_maximum_constants() @@ -1463,33 +1568,34 @@ def start(self): 
self.out.lf().lf() + MODULE_GLOBALS_STRING else: self.out.lf().lf() + NOMPI_MODULE_GLOBALS_STRING - + if self.use_iso_c_bindings: self.out.n() + ISO_ARRAY_DEFINES_STRING else: self.out.n() + ARRAY_DEFINES_STRING - - + self.out.lf().lf() + MAIN_STRING - - self.out.lf().lf() + 'CONTAINS' - + + self.out.lf().lf() + "CONTAINS" + self.out + POLLING_FUNCTIONS_STRING - self.out + GETSET_WORKING_DIRECTORY.format("" if not config.compilers.ifort_version else " use ifport") + self.out + GETSET_WORKING_DIRECTORY.format( + "" if not config.compilers.ifort_version else " use ifport" + ) if self.must_generate_mpi: self.out + INTERNAL_FUNCTIONS_STRING - - if self.use_iso_c_bindings: + + if self.use_iso_c_bindings: self.out + RECV_HEADER_SLEEP_STRING else: self.out + RECV_HEADER_WAIT_STRING - + self.out + RUN_LOOP_MPI_STRING else: self.out + NOMPI_INTERNAL_FUNCTIONS_STRING - + self.out + EMPTY_RUN_LOOP_MPI_STRING if self.use_iso_c_bindings: self.out.n() + RUN_LOOP_SOCKETS_STRING @@ -1501,180 +1607,188 @@ def start(self): else: self.out.n() + EMPTY_RUN_LOOP_SOCKETS_STRING self.out.n() + EMPTY_RUN_LOOP_SOCKETS_MPI_STRING - + self.output_handle_call() self.out.dedent() - self.out.n() + 'end program amuse_worker_program' + self.out.n() + "end program amuse_worker_program" self._result = self.out.string def output_mpi_include(self): self.out.n() + "USE mpi" - + def output_modules(self): self.out.n() - if hasattr(self.specification_class, 'use_modules'): + if hasattr(self.specification_class, "use_modules"): for x in self.specification_class.use_modules: - self.out.n() + 'use ' + x - + self.out.n() + "use " + x + def must_include_declaration_of_function(self, x): - if hasattr(x.specification,"internal_provided"): + if hasattr(x.specification, "internal_provided"): return False - + return True - - + def output_declarations_for_the_functions(self): - if not hasattr(self.specification_class, 'use_modules'): + if not hasattr(self.specification_class, "use_modules"): for x in 
self.interface_functions: if not self.must_include_declaration_of_function(x): continue - + specification = x.specification if specification.id == 0: continue if specification.result_type is None: continue - if specification.result_type == 'string': - type = 'character(len=255)' + if specification.result_type == "string": + type = "character(len=255)" else: spec = self.dtype_to_spec[specification.result_type] type = spec.type - self.out.lf() + type + ' :: ' + specification.name - + self.out.lf() + type + " :: " + specification.name + if self.must_add_underscore_to_function(x): - self.out + '_' - - + self.out + "_" + def must_add_underscore_to_function(self, x): - for cls in self.underscore_functions_from_specification_classes: if hasattr(cls, x.specification.name): return True - + return False - + def output_handle_call(self): - - self.out.lf() + 'integer function handle_call()' + self.out.lf() + "integer function handle_call()" self.out.indent().n() - self.out.lf() + 'implicit none' + self.out.lf() + "implicit none" - self.output_declarations_for_the_functions() - - self.out.lf() + 'integer i, call_count' - self.out.lf() + 'call_count = header_in(HEADER_CALL_COUNT)' - self.out.lf() + 'handle_call = 1' - self.out.lf() + 'SELECT CASE (header_in(HEADER_FUNCTION_ID))' + + self.out.lf() + "integer i, call_count" + self.out.lf() + "call_count = header_in(HEADER_CALL_COUNT)" + self.out.lf() + "handle_call = 1" + self.out.lf() + "SELECT CASE (header_in(HEADER_FUNCTION_ID))" self.out.indent().n() - self.out.lf() + 'CASE(0)' - self.out.indent().lf()+'handle_call = 0' + self.out.lf() + "CASE(0)" + self.out.indent().lf() + "handle_call = 0" self.out.dedent() - + self.output_sourcecode_for_functions() - self.out.lf() + 'CASE DEFAULT' + self.out.lf() + "CASE DEFAULT" self.out.indent() - self.out.lf() + 'header_out(HEADER_STRING_COUNT) = 1' - self.out.lf() + 'header_out(HEADER_FLAGS) = IOR(header_out(HEADER_FLAGS), 256) ' - self.out.lf() + "strings_out(1) = 'error, illegal 
function id'" + self.out.lf() + "header_out(HEADER_STRING_COUNT) = 1" + self.out.lf() + "header_out(HEADER_FLAGS) = IOR(header_out(HEADER_FLAGS), 256) " + self.out.lf() + "strings_out(1) = 'error, illegal function id'" self.out.dedent() - - self.out.dedent().n() + 'END SELECT' - self.out.n() + 'return' + self.out.dedent().n() + "END SELECT" + + self.out.n() + "return" self.out.dedent() - self.out.n() + 'end function' - + self.out.n() + "end function" + def output_maximum_constants(self): - - self.out.lf() + 'integer MAX_INTEGERS_IN, MAX_INTEGERS_OUT, MAX_LONGS_IN, MAX_LONGS_OUT, &' - self.out.lf() + 'MAX_FLOATS_IN, MAX_FLOATS_OUT, MAX_DOUBLES_IN,MAX_DOUBLES_OUT, &' - self.out.lf() + 'MAX_BOOLEANS_IN,MAX_BOOLEANS_OUT, MAX_STRINGS_IN, MAX_STRINGS_OUT' + ( + self.out.lf() + + "integer MAX_INTEGERS_IN, MAX_INTEGERS_OUT, MAX_LONGS_IN, MAX_LONGS_OUT, &" + ) + ( + self.out.lf() + + "MAX_FLOATS_IN, MAX_FLOATS_OUT, MAX_DOUBLES_IN,MAX_DOUBLES_OUT, &" + ) + ( + self.out.lf() + + "MAX_BOOLEANS_IN,MAX_BOOLEANS_OUT, MAX_STRINGS_IN, MAX_STRINGS_OUT" + ) self.out.lf() for dtype in self.dtype_to_spec.keys(): dtype_spec = self.dtype_to_spec[dtype] - maximum = self.mapping_from_dtype_to_maximum_number_of_inputvariables.get(dtype,0) - - self.out.n() + 'parameter (MAX_' + dtype_spec.input_var_name.upper() + '=' + maximum + ')' - - maximum =self.mapping_from_dtype_to_maximum_number_of_outputvariables.get(dtype,0) - - self.out.n() + 'parameter (MAX_' + dtype_spec.output_var_name.upper() + '=' + maximum + ')' - - -class GenerateAFortranStubStringFromASpecificationClass\ - (GenerateASourcecodeStringFromASpecificationClass): - + maximum = self.mapping_from_dtype_to_maximum_number_of_inputvariables.get( + dtype, 0 + ) + + ( + self.out.n() + + "parameter (MAX_" + + dtype_spec.input_var_name.upper() + + "=" + + maximum + + ")" + ) + + maximum = self.mapping_from_dtype_to_maximum_number_of_outputvariables.get( + dtype, 0 + ) + + ( + self.out.n() + + "parameter (MAX_" + + 
dtype_spec.output_var_name.upper() + + "=" + + maximum + + ")" + ) + + +class GenerateAFortranStubStringFromASpecificationClass( + GenerateASourcecodeStringFromASpecificationClass +): @late def dtype_to_spec(self): return dtype_to_spec - + @late def ignore_functions_from_specification_classes(self): return [] - + @late def underscore_functions_from_specification_classes(self): return [] - + def output_sourcecode_for_function(self): result = create_definition.CreateFortranStub() result.output_definition_only = False return result - - def start(self): - if hasattr(self.specification_class, 'use_modules'): - self.out.lf() + 'module {0}'.format(self.specification_class.use_modules[0]) - - self.out.indent() - + def start(self): + if hasattr(self.specification_class, "use_modules"): + self.out.lf() + "module {0}".format(self.specification_class.use_modules[0]) + + self.out.indent() + self.output_modules(1) - - if hasattr(self.specification_class, 'use_modules'): - self.out.lf() + "contains" + + if hasattr(self.specification_class, "use_modules"): + self.out.lf() + "contains" self.out.lf() - + self.output_sourcecode_for_functions() - + self.out.lf() - if hasattr(self.specification_class, 'use_modules'): + if hasattr(self.specification_class, "use_modules"): self.out.dedent() self.out.lf() + "end module" self.out.lf() - + self._result = self.out.string - - + def must_include_interface_function_in_output(self, x): - if hasattr(x.specification,"internal_provided"): + if hasattr(x.specification, "internal_provided"): return False - + for cls in self.ignore_functions_from_specification_classes: if hasattr(cls, x.specification.name): return False - + return True - - def output_modules(self,skip=0): + + def output_modules(self, skip=0): self.out.n() - if hasattr(self.specification_class, 'use_modules'): + if hasattr(self.specification_class, "use_modules"): for x in self.specification_class.use_modules[skip:]: - self.out.n() + 'use ' + x - - - - - - - - - - - + self.out.n() + 
"use " + x diff --git a/src/amuse/rfi/tools/create_java.py b/src/amuse/rfi/tools/create_java.py index 05b5f3c6f7..7244c74f7a 100644 --- a/src/amuse/rfi/tools/create_java.py +++ b/src/amuse/rfi/tools/create_java.py @@ -11,19 +11,16 @@ import os import inspect -dtype_to_spec = DTypeToSpecDictionary({ - 'int32' : DTypeSpec('Int', 'Int', '', 'int', ''), - 'int64' : DTypeSpec('Long', 'Long', - '', 'long', ''), - 'float32' : DTypeSpec('Float', 'Float', - '', 'float', ''), - 'float64' : DTypeSpec('Double', 'Double', - '', 'double', ''), - 'bool' : DTypeSpec('Boolean', 'Boolean', - '', 'boolean', ''), - 'string' : DTypeSpec('String', 'String', - '', 'String', ''), -}) +dtype_to_spec = DTypeToSpecDictionary( + { + "int32": DTypeSpec("Int", "Int", "", "int", ""), + "int64": DTypeSpec("Long", "Long", "", "long", ""), + "float32": DTypeSpec("Float", "Float", "", "float", ""), + "float64": DTypeSpec("Double", "Double", "", "double", ""), + "bool": DTypeSpec("Boolean", "Boolean", "", "boolean", ""), + "string": DTypeSpec("String", "String", "", "String", ""), + } +) IMPORTS_CODE_STRING = """ import java.io.IOException; @@ -1075,7 +1072,6 @@ """ - FOOTER_CODE_STRING = """ private final AmuseMessage request; private final AmuseMessage reply; @@ -1160,21 +1156,22 @@ """ + class MakeJavaCodeString(GenerateASourcecodeString): @late def dtype_to_spec(self): return dtype_to_spec - - + class GenerateAJavaStringOfAFunctionSpecification(MakeJavaCodeString): @late def specification(self): - raise exceptions.AmuseException("No specification set, please set the specification first") - - + raise exceptions.AmuseException( + "No specification set, please set the specification first" + ) + def start(self): - #must and can handle array is the same thing in Java codes... + # must and can handle array is the same thing in Java codes... 
if self.specification.can_handle_array: self.specification.must_handle_array = True @@ -1183,12 +1180,11 @@ def start(self): self.out.indent() self.out.lf() + "{" self.out.indent() - + self.output_lines_with_number_of_outputs() - - - if hasattr(self.specification,"internal_provided"): - self.out.lf() + "//" + self.specification.name + " ignored" + + if hasattr(self.specification, "internal_provided"): + self.out.lf() + "//" + self.specification.name + " ignored" else: self.output_declare_variables() self.output_function_start() @@ -1201,280 +1197,369 @@ def start(self): self.output_casestmt_end() self.out.dedent() self._result = self.out.string - + def output_casestmt_start(self): - self.out + 'case ' + self.specification.id + ':' - - + self.out + "case " + self.specification.id + ":" + def output_lines_with_number_of_outputs(self): dtype_to_count = {} - + for parameter in self.specification.output_parameters: count = dtype_to_count.get(parameter.datatype, 0) dtype_to_count[parameter.datatype] = count + 1 - + if not self.specification.result_type is None: count = dtype_to_count.get(self.specification.result_type, 0) dtype_to_count[self.specification.result_type] = count + 1 - - for dtype in dtype_to_count: + + for dtype in dtype_to_count: spec = self.dtype_to_spec[dtype] count = dtype_to_count[dtype] - self.out.lf() + 'reply.set' + spec.input_var_name + 'Count(' + count + ' * count);' + ( + self.out.lf() + + "reply.set" + + spec.input_var_name + + "Count(" + + count + + " * count);" + ) pass - - self.out.lf() + 'reply.ensurePrimitiveCapacity();' - - + + self.out.lf() + "reply.ensurePrimitiveCapacity();" + def output_function_parameters(self): self.out.indent() - + first = True - + for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if first: first = False else: - self.out + ', ' - + self.out + ", " + if parameter.direction == LegacyFunctionSpecification.IN: if self.specification.must_handle_array: self.out + 
parameter.name else: - self.out + parameter.name + '[0]' + self.out + parameter.name + "[0]" if parameter.direction == LegacyFunctionSpecification.INOUT: - self.out + parameter.name + self.out + parameter.name elif parameter.direction == LegacyFunctionSpecification.OUT: - self.out + parameter.name + self.out + parameter.name elif parameter.direction == LegacyFunctionSpecification.LENGTH: - self.out + 'count' - + self.out + "count" + self.out.dedent() def output_declare_variables(self): if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] - self.out.lf() + spec.type + ' functionResult;' - + self.out.lf() + spec.type + " functionResult;" + for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - - if parameter.direction == LegacyFunctionSpecification.IN or parameter.direction == LegacyFunctionSpecification.INOUT : - self.out.lf() + spec.type + '[] ' + parameter.name + ' = request.get' + spec.input_var_name + 'Slice(' + parameter.input_index + ');' + + if ( + parameter.direction == LegacyFunctionSpecification.IN + or parameter.direction == LegacyFunctionSpecification.INOUT + ): + ( + self.out.lf() + + spec.type + + "[] " + + parameter.name + + " = request.get" + + spec.input_var_name + + "Slice(" + + parameter.input_index + + ");" + ) if parameter.direction == LegacyFunctionSpecification.OUT: - self.out.lf() + spec.type + '[] ' + parameter.name + ' = new ' + spec.type + '[count];' + ( + self.out.lf() + + spec.type + + "[] " + + parameter.name + + " = new " + + spec.type + + "[count];" + ) def output_function_start(self): - self.out.n() + self.out.n() if not self.specification.result_type is None: - self.out + 'functionResult = ' - - self.out + 'code.' + self.specification.name + '(' + self.out + "functionResult = " + + self.out + "code." 
+ self.specification.name + "(" def output_function_end(self): - self.out + ')' + ';' - + self.out + ")" + ";" + def output_copy_output_variables(self): if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] - self.out.lf() + 'reply.set' + spec.output_var_name + 'Slice(0, functionResult);' + ( + self.out.lf() + + "reply.set" + + spec.output_var_name + + "Slice(0, functionResult);" + ) for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - - if parameter.direction == LegacyFunctionSpecification.OUT or parameter.direction == LegacyFunctionSpecification.INOUT: - self.out.lf() + 'reply.set' + spec.output_var_name + 'Slice(' + parameter.output_index + ', ' + parameter.name + ');' - + + if ( + parameter.direction == LegacyFunctionSpecification.OUT + or parameter.direction == LegacyFunctionSpecification.INOUT + ): + ( + self.out.lf() + + "reply.set" + + spec.output_var_name + + "Slice(" + + parameter.output_index + + ", " + + parameter.name + + ");" + ) + def output_casestmt_end(self): - self.out.n() + 'break;' + self.out.n() + "break;" -class GenerateAJavaFunctionDeclarationStringFromAFunctionSpecification(MakeJavaCodeString): - - + +class GenerateAJavaFunctionDeclarationStringFromAFunctionSpecification( + MakeJavaCodeString +): def start(self): - #must and can handle array is the same thing in Java codes... + # must and can handle array is the same thing in Java codes... 
if self.specification.can_handle_array: self.specification.must_handle_array = True - + self.output_function_parameter_types() self.output_function_start() self.output_function_parameters() self.output_function_end() self._result = self.out.string - - def output_function_parameter_types(self): + + def output_function_parameter_types(self): for parameter in self.specification.parameters: - if (parameter.direction == LegacyFunctionSpecification.IN): - self.out.lf() + '// parameter "' + parameter.name + '" is an input parameter' - elif (parameter.direction == LegacyFunctionSpecification.OUT): - self.out.lf() + '// parameter "' + parameter.name + '" is an output parameter' - elif (parameter.direction == LegacyFunctionSpecification.INOUT): - self.out.lf() + '// parameter "' + parameter.name + '" is an inout parameter' - elif (parameter.direction == LegacyFunctionSpecification.LENGTH): - self.out.lf() + '// parameter "' + parameter.name + '" is a length parameter' - - def output_function_parameters(self): + if parameter.direction == LegacyFunctionSpecification.IN: + ( + self.out.lf() + + '// parameter "' + + parameter.name + + '" is an input parameter' + ) + elif parameter.direction == LegacyFunctionSpecification.OUT: + ( + self.out.lf() + + '// parameter "' + + parameter.name + + '" is an output parameter' + ) + elif parameter.direction == LegacyFunctionSpecification.INOUT: + ( + self.out.lf() + + '// parameter "' + + parameter.name + + '" is an inout parameter' + ) + elif parameter.direction == LegacyFunctionSpecification.LENGTH: + ( + self.out.lf() + + '// parameter "' + + parameter.name + + '" is a length parameter' + ) + + def output_function_parameters(self): first = True - + for parameter in self.specification.parameters: spec = self.dtype_to_spec[parameter.datatype] - + if first: first = False else: - self.out + ', ' - + self.out + ", " + self.out + spec.type - if ((self.specification.must_handle_array and parameter.is_input()) or parameter.is_output()): - 
self.out + '[]' - - self.out + ' ' + if ( + self.specification.must_handle_array and parameter.is_input() + ) or parameter.is_output(): + self.out + "[]" + + self.out + " " self.out + parameter.name - - + def output_function_end(self): - self.out + ')' + ';' - + self.out + ")" + ";" + def output_function_start(self): self.out.n() if not self.specification.result_type is None: spec = self.dtype_to_spec[self.specification.result_type] self.out + spec.type - self.out + ' ' + self.out + " " else: - self.out + 'void' + ' ' - self.out + self.specification.name + '(' - -class GenerateAJavaSourcecodeStringFromASpecificationClass\ - (GenerateASourcecodeStringFromASpecificationClass): + self.out + "void" + " " + self.out + self.specification.name + "(" + +class GenerateAJavaSourcecodeStringFromASpecificationClass( + GenerateASourcecodeStringFromASpecificationClass +): @late def specification_class(self): - raise exceptions.AmuseException("No specification_class set, please set the specification_class first") - + raise exceptions.AmuseException( + "No specification_class set, please set the specification_class first" + ) + @late def dtype_to_spec(self): return dtype_to_spec def output_sourcecode_for_function(self): return GenerateAJavaStringOfAFunctionSpecification() - + def start(self): - self.out.lf() self.out + IMPORTS_CODE_STRING - - self.out.lf() + 'class Worker {' + + self.out.lf() + "class Worker {" self.out.indent().lf() self.out + AMUSE_MESSAGE_CLASS_CODE_STRING - self.output_handle_call() - + self.out.lf() + FOOTER_CODE_STRING self.out.dedent().lf() - + self.out.lf() + "}" - + self._result = self.out.string - + def output_code_constants(self): for dtype in list(self.dtype_to_spec.keys()): dtype_spec = self.dtype_to_spec[dtype] - - maxin = self.mapping_from_dtype_to_maximum_number_of_inputvariables.get(dtype, 0) - self.out + 'static int MAX_' + dtype_spec.input_var_name.upper() + ' = ' + maxin + ";" + + maxin = 
self.mapping_from_dtype_to_maximum_number_of_inputvariables.get( + dtype, 0 + ) + ( + self.out + + "static int MAX_" + + dtype_spec.input_var_name.upper() + + " = " + + maxin + + ";" + ) self.out.lf() - - maxout = self.mapping_from_dtype_to_maximum_number_of_outputvariables.get(dtype, 0) - self.out + 'static int MAX_' + dtype_spec.output_var_name.upper() + ' = ' + maxout + ";" + + maxout = self.mapping_from_dtype_to_maximum_number_of_outputvariables.get( + dtype, 0 + ) + ( + self.out + + "static int MAX_" + + dtype_spec.output_var_name.upper() + + " = " + + maxout + + ";" + ) self.out.lf() - + def output_handle_call(self): - - self.out.lf().lf() + 'private boolean handleCall() throws IOException {' + self.out.lf().lf() + "private boolean handleCall() throws IOException {" self.out.indent() - - self.out.lf() + 'int count = request.getCallCount();' - - self.out.lf().lf() + 'switch (request.getFunctionID()) {' + + self.out.lf() + "int count = request.getCallCount();" + + self.out.lf().lf() + "switch (request.getFunctionID()) {" self.out.indent() - self.out.lf() + 'case 0:' + self.out.lf() + "case 0:" self.out.indent() - self.out.lf() + 'code.end();' - self.out.lf() + 'return false;' + self.out.lf() + "code.end();" + self.out.lf() + "return false;" self.out.dedent() - + self.output_sourcecode_for_functions() - - self.out.lf() + 'default:' + + self.out.lf() + "default:" self.out.indent() - self.out.lf() + 'System.err.println("unknown function id " + request.getFunctionID());' - self.out.lf() + 'reply.setError("unknown function id " + request.getFunctionID());' + ( + self.out.lf() + + 'System.err.println("unknown function id " + request.getFunctionID());' + ) + ( + self.out.lf() + + 'reply.setError("unknown function id " + request.getFunctionID());' + ) self.out.dedent() - - self.out.dedent().lf() + '}' + + self.out.dedent().lf() + "}" self.out.dedent() - self.out.indent().lf() + 'return true;' - self.out.dedent().lf() + '}' + self.out.indent().lf() + "return true;" + 
self.out.dedent().lf() + "}" -class GenerateAJavaInterfaceStringFromASpecificationClass\ - (GenerateASourcecodeStringFromASpecificationClass): +class GenerateAJavaInterfaceStringFromASpecificationClass( + GenerateASourcecodeStringFromASpecificationClass +): @late def ignore_functions_from_specification_classes(self): return [] - + @late def underscore_functions_from_specification_classes(self): return [] - + @late def dtype_to_spec(self): return dtype_to_spec - - + def must_include_interface_function_in_output(self, x): - if hasattr(x.specification,"internal_provided"): + if hasattr(x.specification, "internal_provided"): return False - + for cls in self.ignore_functions_from_specification_classes: if hasattr(cls, x.specification.name): return False - + return True - + def output_sourcecode_for_function(self): return GenerateAJavaFunctionDeclarationStringFromAFunctionSpecification() - - def start(self): - self.out + 'public interface CodeInterface {' + + def start(self): + self.out + "public interface CodeInterface {" self.out.indent().lf() - - self.out + 'public void end();' - self.out.lf() - + + self.out + "public void end();" + self.out.lf() + self.output_sourcecode_for_functions() - - self.out.dedent().lf() + '}' - + + self.out.dedent().lf() + "}" + self.out.lf() - + self._result = self.out.string - -class GenerateAJavaWorkerScript(GenerateASourcecodeString): + +class GenerateAJavaWorkerScript(GenerateASourcecodeString): @late def amuse_root_dir(self): return os.path.abspath(options.GlobalOptions.instance().amuse_rootdirectory) @late def code_dir(self): - codedir=os.path.split(self.code_directory())[-1] + codedir = os.path.split(self.code_directory())[-1] return os.path.join("community", codedir) @late @@ -1484,37 +1569,35 @@ def java(self): @late def template_dir(self): return os.path.dirname(__file__) - + @late def template_string(self): path = self.template_dir - path = os.path.join(path, 'java_code_script.template') - + path = os.path.join(path, 
"java_code_script.template") + with open(path, "r") as f: template_string = f.read() - + return template_string - + @staticmethod def classpath(classpath, code_dir): return ":".join([os.path.join(code_dir, x) for x in classpath]) - + def script_string(self): return self.template_string.format( - executable = sys.executable, - java = self.java, - classpath = self.classpath(self.specification_class.classpath, self.code_dir), - code_dir = self.code_dir, - amuse_root_dir = self.amuse_root_dir - ) - + executable=sys.executable, + java=self.java, + classpath=self.classpath(self.specification_class.classpath, self.code_dir), + code_dir=self.code_dir, + amuse_root_dir=self.amuse_root_dir, + ) def code_directory(self): interface_module = inspect.getmodule(self.specification_class).__name__ return os.path.dirname(inspect.getfile(self.specification_class)) - + def start(self): self.out + self.script_string() - - self._result = self.out.string + self._result = self.out.string diff --git a/src/amuse/rfi/tools/create_python_worker.py b/src/amuse/rfi/tools/create_python_worker.py index 8787233eeb..82b81cc895 100644 --- a/src/amuse/rfi/tools/create_python_worker.py +++ b/src/amuse/rfi/tools/create_python_worker.py @@ -7,75 +7,71 @@ import os import inspect import sys - class CreateAPythonWorker(OptionalAttributes): - - @option(sections=['data']) + @option(sections=["data"]) def amuse_root_dir(self): return get_amuse_root_dir() - + @late def channel_type(self): - return 'mpi' - + return "mpi" + @late def template_dir(self): return os.path.dirname(__file__) - + @late def worker_dir(self): return os.path.abspath(os.path.curdir) - + @late def template_string(self): path = self.template_dir - path = os.path.join(path, 'python_code_script.template') - + path = os.path.join(path, "python_code_script.template") + with open(path, "r") as f: template_string = f.read() - + return template_string - - + @late def worker_name(self): filename = 
os.path.basename(inspect.getfile(self.implementation_factory)) - filename = filename.split('.')[0] - filename.replace(os.sep, '_') + filename = filename.split(".")[0] + filename.replace(os.sep, "_") path = os.path.join(self.worker_dir, filename) - + return path - + @late def output_name(self): executable_path = self.worker_name return executable_path - + @late def interface_class(self): return self.specification_class - + def new_executable_script_string(self): return self.template_string.format( - executable = sys.executable, - syspath = ','.join(map(repr, sys.path)), - factory_module = inspect.getmodule(self.implementation_factory).__name__, - factory = self.implementation_factory.__name__, - interface_module = inspect.getmodule(self.interface_class).__name__, - interface = self.interface_class.__name__, + executable=sys.executable, + syspath=",".join(map(repr, sys.path)), + factory_module=inspect.getmodule(self.implementation_factory).__name__, + factory=self.implementation_factory.__name__, + interface_module=inspect.getmodule(self.interface_class).__name__, + interface=self.interface_class.__name__, ) - + @property def result(self): return self.new_executable_script_string() - + def start(self): string = self.new_executable_script_string() - - with open(self.output_name, 'w') as f: + + with open(self.output_name, "w") as f: f.write(string) - + os.chmod(self.output_name, 0o777) - diff --git a/src/amuse/rfi/tools/fortran_tools.py b/src/amuse/rfi/tools/fortran_tools.py index 1a26773816..cc00b05fe4 100644 --- a/src/amuse/rfi/tools/fortran_tools.py +++ b/src/amuse/rfi/tools/fortran_tools.py @@ -1,13 +1,14 @@ -from amuse.support.literature import TrackLiteratureReferences +from amuse.support.literature import TrackLiteratureReferences # ptype: # simple: rw, scalar value and implementation generated by interface -# normal: rw, scalar value, custom implementation -# ro: read only, scalar value, generated +# normal: rw, scalar value, custom implementation +# ro: read 
only, scalar value, generated # vector: rw, generated vector valued + class FortranCodeGenerator(object): - _getter_string=""" + _getter_string = """ function get_{0}(x) result(ret) integer :: ret {1} :: x @@ -15,7 +16,7 @@ class FortranCodeGenerator(object): ret=0 end function """ - _setter_string=""" + _setter_string = """ function set_{0}(x) result(ret) integer :: ret {1} :: x @@ -23,8 +24,8 @@ class FortranCodeGenerator(object): ret=0 end function """ - - _vector_getter_string=""" + + _vector_getter_string = """ function get_{0}(i,x) result(ret) integer :: i,ret {1} :: x @@ -32,7 +33,7 @@ class FortranCodeGenerator(object): ret=0 end function """ - _vector_setter_string=""" + _vector_setter_string = """ function set_{0}(i,x) result(ret) integer :: i,ret {1} :: x @@ -41,7 +42,7 @@ class FortranCodeGenerator(object): end function """ - _grid_getter_template=""" + _grid_getter_template = """ function get_{0}({2},{0}_out_,n) result(ret) integer :: n,{3},k,ret {1} :: {0}_out_(n) @@ -53,7 +54,7 @@ class FortranCodeGenerator(object): end function """ - _grid_setter_template=""" + _grid_setter_template = """ function set_{0}({2},{0}_in_,n) result(ret) integer :: n,{3},k,ret {1} :: {0}_in_(n) @@ -63,83 +64,109 @@ class FortranCodeGenerator(object): {0}({4})={0}_in_(k) enddo end function - """ - - datatypedict={"string" : "character(len=*) ", - "float64" : "real*8", - "float32" : "real", - "int32" : "integer", - "bool" : "logical", - } + """ + + datatypedict = { + "string": "character(len=*) ", + "float64": "real*8", + "float32": "real", + "int32": "integer", + "bool": "logical", + } - def __init__(self,parameter_definition=None, grid_variable_definition=None): + def __init__(self, parameter_definition=None, grid_variable_definition=None): if parameter_definition is None: - parameter_definition=dict() + parameter_definition = dict() if grid_variable_definition is None: - grid_variable_definition=dict() - self.parameter_definition=parameter_definition - 
self.grid_variable_definition=grid_variable_definition + grid_variable_definition = dict() + self.parameter_definition = parameter_definition + self.grid_variable_definition = grid_variable_definition + + def _grid_format_arg(self, name, dtype, ndim=1, index_ranges=None): + arg = ",".join(["i{0}".format(i) for i in range(ndim)]) + dec = ",".join(["i{0}(n)".format(i) for i in range(ndim)]) + assign = ",".join(["i{0}(k)".format(i) for i in range(ndim)]) + if index_ranges is None: + check = "! no check on index range" + else: + checks = (".OR. &\n" + " " * 11).join( + [ + "({0}.LT.{1}.OR.{0}.GT.{2})".format( + "i{0}(k)".format(i), index_ranges[i][0], index_ranges[i][1] + ) + for i in range(ndim) + ] + ) + check = ( + "if({0}) then\n".format(checks) + + " " * 10 + + "ret=-1\n" + + " " * 10 + + "exit\n" + + " " * 8 + + "endif" + ) + dtype = self.datatypedict[dtype] + return name, dtype, arg, dec, assign, check - def _grid_format_arg(self,name, dtype,ndim=1, index_ranges=None): - arg=','.join(['i{0}'.format(i) for i in range(ndim)]) - dec=','.join(['i{0}(n)'.format(i) for i in range(ndim)]) - assign=','.join(['i{0}(k)'.format(i) for i in range(ndim)]) - if index_ranges is None: - check="! no check on index range" - else: - checks=(".OR. 
&\n"+" "*11).join(["({0}.LT.{1}.OR.{0}.GT.{2})".format('i{0}(k)'.format(i), - index_ranges[i][0],index_ranges[i][1]) for i in range(ndim)]) - check="if({0}) then\n".format(checks)+ \ - " "*10+"ret=-1\n"+ \ - " "*10+"exit\n"+ \ - " "*8+"endif" - dtype=self.datatypedict[dtype] - return name,dtype,arg,dec,assign,check - - def grid_getter(self, name, dtype,ndim=1, index_ranges=None): - return self._grid_getter_template.format(*self._grid_format_arg(name,dtype,ndim,index_ranges)) + def grid_getter(self, name, dtype, ndim=1, index_ranges=None): + return self._grid_getter_template.format( + *self._grid_format_arg(name, dtype, ndim, index_ranges) + ) - def grid_setter(self, name, dtype,ndim=1, index_ranges=None): - return self._grid_setter_template.format(*self._grid_format_arg(name,dtype,ndim,index_ranges)) + def grid_setter(self, name, dtype, ndim=1, index_ranges=None): + return self._grid_setter_template.format( + *self._grid_format_arg(name, dtype, ndim, index_ranges) + ) def parameter_getter_setters(self): - filestring="" - py_to_f=self.datatypedict - for par,d in self.parameter_definition.items(): - if d["ptype"] in ["ro"]: - filestring+=self._getter_string.format(d["short"],py_to_f[d["dtype"]]) - if d["ptype"] in ["simple"]: - filestring+=self._setter_string.format(d["short"],py_to_f[d["dtype"]]) - filestring+=self._getter_string.format(d["short"],py_to_f[d["dtype"]]) - if d["ptype"] in ["vector"]: - filestring+=self._vector_setter_string.format(d["short"],py_to_f[d["dtype"]]) - filestring+=self._vector_getter_string.format(d["short"],py_to_f[d["dtype"]]) - + filestring = "" + py_to_f = self.datatypedict + for par, d in self.parameter_definition.items(): + if d["ptype"] in ["ro"]: + filestring += self._getter_string.format( + d["short"], py_to_f[d["dtype"]] + ) + if d["ptype"] in ["simple"]: + filestring += self._setter_string.format( + d["short"], py_to_f[d["dtype"]] + ) + filestring += self._getter_string.format( + d["short"], py_to_f[d["dtype"]] + ) + if 
d["ptype"] in ["vector"]: + filestring += self._vector_setter_string.format( + d["short"], py_to_f[d["dtype"]] + ) + filestring += self._vector_getter_string.format( + d["short"], py_to_f[d["dtype"]] + ) + return filestring def grid_getter_setters(self): - string="" + string = "" for var, d in self.grid_variable_definition.items(): - vartype=d.get("vartype", None) - dtype=d.get("dtype", "float64") - ndim=d.get("ndim", 1) - index_ranges=d.get("index_ranges",None) + vartype = d.get("vartype", None) + dtype = d.get("dtype", "float64") + ndim = d.get("ndim", 1) + index_ranges = d.get("index_ranges", None) for forvar in d["forvar"]: - string+=self.grid_getter(forvar, dtype, ndim, index_ranges) - if vartype!="ro": - string+=self.grid_setter(forvar, dtype, ndim, index_ranges) + string += self.grid_getter(forvar, dtype, ndim, index_ranges) + if vartype != "ro": + string += self.grid_setter(forvar, dtype, ndim, index_ranges) return string - def generate_getters_setters(self,filename=None): - filestring="" - #~ filestring+=input_grid_string(_unstructured_input_grid_template) - #~ filestring+=input_grid_string(_regular_input_grid_template) - filestring+=self.parameter_getter_setters() - filestring+=self.grid_getter_setters() + def generate_getters_setters(self, filename=None): + filestring = "" + # filestring+=input_grid_string(_unstructured_input_grid_template) + # filestring+=input_grid_string(_regular_input_grid_template) + filestring += self.parameter_getter_setters() + filestring += self.grid_getter_setters() if filename is None: return filestring else: - with open(filename,"w") as f: + with open(filename, "w") as f: f.write(filestring) def print_getters_setters(self): @@ -147,119 +174,212 @@ def print_getters_setters(self): print(self.generate_getters_setters()) def generate_parameter_interface_functions(self): - output="" - for par,d in self.parameter_definition.items(): - dtype=d["dtype"] - if hasattr(d["default"],"unit"): - unit=d["default"].unit.reference_string() + 
output = "" + for par, d in self.parameter_definition.items(): + dtype = d["dtype"] + if hasattr(d["default"], "unit"): + unit = d["default"].unit.reference_string() else: - unit="None" - short=d["short"] - ptype=d["ptype"] + unit = "None" + short = d["short"] + ptype = d["ptype"] if ptype in ["ro"]: - output+=("@legacy_function\ndef get_"+short+"():\n function = LegacyFunctionSpecification()\n" - " function.addParameter('"+short+"', dtype='"+dtype+"', direction=function.OUT, unit="+unit+")\n" - " function.result_type = 'int32'\n return function\n") + output += ( + "@legacy_function\ndef get_" + + short + + "():\n function = LegacyFunctionSpecification()\n" + " function.addParameter('" + + short + + "', dtype='" + + dtype + + "', direction=function.OUT, unit=" + + unit + + ")\n" + " function.result_type = 'int32'\n return function\n" + ) if ptype in ["simple"]: - output+=("@legacy_function\ndef get_"+short+"():\n function = LegacyFunctionSpecification()\n" - " function.addParameter('"+short+"', dtype='"+dtype+"', direction=function.OUT, unit="+unit+")\n" - " function.result_type = 'int32'\n return function\n") - output+=("@legacy_function\ndef set_"+short+"():\n function = LegacyFunctionSpecification()\n" - " function.addParameter('"+short+"', dtype='"+dtype+"', direction=function.IN, unit="+unit+")\n" - " function.result_type = 'int32'\n return function\n") + output += ( + "@legacy_function\ndef get_" + + short + + "():\n function = LegacyFunctionSpecification()\n" + " function.addParameter('" + + short + + "', dtype='" + + dtype + + "', direction=function.OUT, unit=" + + unit + + ")\n" + " function.result_type = 'int32'\n return function\n" + ) + output += ( + "@legacy_function\ndef set_" + + short + + "():\n function = LegacyFunctionSpecification()\n" + " function.addParameter('" + + short + + "', dtype='" + + dtype + + "', direction=function.IN, unit=" + + unit + + ")\n" + " function.result_type = 'int32'\n return function\n" + ) if ptype in ["vector"]: - 
output+=("@legacy_function\ndef get_"+short+"():\n function = LegacyFunctionSpecification()\n" - " function.addParameter('i', dtype='i', direction=function.IN)\n" - " function.addParameter('"+short+"', dtype='"+dtype+"', direction=function.OUT, unit="+unit+")\n" - " function.can_handle_array=True\n" - " function.result_type = 'int32'\n return function\n") - output+=("@legacy_function\ndef set_"+short+"():\n function = LegacyFunctionSpecification()\n" - " function.addParameter('i', dtype='i', direction=function.IN)\n" - " function.addParameter('"+short+"', dtype='"+dtype+"', direction=function.IN, unit="+unit+")\n" - " function.can_handle_array=True\n" - " function.result_type = 'int32'\n return function\n") - length=d["length"] - output+=( "def get_"+short+"_range(self):\n" + ( - (" return 1," + str(length)) if isinstance(length, int) else - (" return 1, self.get_"+length+"()['"+length+"']\n") ) + output += ( + "@legacy_function\ndef get_" + + short + + "():\n function = LegacyFunctionSpecification()\n" + " function.addParameter('i', dtype='i', direction=function.IN)\n" + " function.addParameter('" + + short + + "', dtype='" + + dtype + + "', direction=function.OUT, unit=" + + unit + + ")\n" + " function.can_handle_array=True\n" + " function.result_type = 'int32'\n return function\n" + ) + output += ( + "@legacy_function\ndef set_" + + short + + "():\n function = LegacyFunctionSpecification()\n" + " function.addParameter('i', dtype='i', direction=function.IN)\n" + " function.addParameter('" + + short + + "', dtype='" + + dtype + + "', direction=function.IN, unit=" + + unit + + ")\n" + " function.can_handle_array=True\n" + " function.result_type = 'int32'\n return function\n" + ) + length = d["length"] + output += ( + "def get_" + + short + + "_range(self):\n" + + ( + (" return 1," + str(length)) + if isinstance(length, int) + else ( + " return 1, self.get_" + length + "()['" + length + "']\n" + ) ) + ) return output def generate_grid_interface_functions(self): - 
output="" + output = "" for var, d in self.grid_variable_definition.items(): - vartype=d.get("vartype", None) - dtype=d.get("dtype", "float64") - dtype=dtype.__name__ if isinstance(dtype, type) else str(dtype) - ndim=d.get("ndim", 1) - index_ranges=d.get("index_ranges",None) - unit=d.get("unit",None) - unit="None" if unit is None else unit.reference_string() + vartype = d.get("vartype", None) + dtype = d.get("dtype", "float64") + dtype = dtype.__name__ if isinstance(dtype, type) else str(dtype) + ndim = d.get("ndim", 1) + index_ranges = d.get("index_ranges", None) + unit = d.get("unit", None) + unit = "None" if unit is None else unit.reference_string() for pyvar, forvar in zip(d["pyvar"], d["forvar"]): - if vartype!="ro": - output+=("@legacy_function\ndef set_"+forvar+"():\n function = LegacyFunctionSpecification()\n" + \ - "".join([" function.addParameter('index{0}', dtype='i', direction=function.IN)\n".format(i) for i in range(ndim)]) + \ - " function.addParameter('"+pyvar+"', dtype='"+dtype+"', direction=function.IN, unit="+unit+")\n" + \ - " function.addParameter('n', direction=function.LENGTH)\n" + \ - " function.must_handle_array=True\n" + \ - " function.result_type = 'int32'\n return function\n") - output+=("@legacy_function\ndef get_"+forvar+"():\n function = LegacyFunctionSpecification()\n" + \ - "".join([" function.addParameter('index{0}', dtype='i', direction=function.IN)\n".format(i) for i in range(ndim)]) + \ - " function.addParameter('"+pyvar+"', dtype='"+dtype+"', direction=function.OUT, unit="+unit+")\n" + \ - " function.addParameter('n', direction=function.LENGTH)\n" + \ - " function.must_handle_array=True\n" + \ - " function.result_type = 'int32'\n return function\n") + if vartype != "ro": + output += ( + "@legacy_function\ndef set_" + + forvar + + "():\n function = LegacyFunctionSpecification()\n" + + "".join( + [ + " function.addParameter('index{0}', dtype='i', direction=function.IN)\n".format( + i + ) + for i in range(ndim) + ] + ) + + " 
function.addParameter('" + + pyvar + + "', dtype='" + + dtype + + "', direction=function.IN, unit=" + + unit + + ")\n" + + " function.addParameter('n', direction=function.LENGTH)\n" + + " function.must_handle_array=True\n" + + " function.result_type = 'int32'\n return function\n" + ) + output += ( + "@legacy_function\ndef get_" + + forvar + + "():\n function = LegacyFunctionSpecification()\n" + + "".join( + [ + " function.addParameter('index{0}', dtype='i', direction=function.IN)\n".format( + i + ) + for i in range(ndim) + ] + ) + + " function.addParameter('" + + pyvar + + "', dtype='" + + dtype + + "', direction=function.OUT, unit=" + + unit + + ")\n" + + " function.addParameter('n', direction=function.LENGTH)\n" + + " function.must_handle_array=True\n" + + " function.result_type = 'int32'\n return function\n" + ) return output def generate_interface_functions(self): - output="" - output+=self.generate_parameter_interface_functions() - output+=self.generate_grid_interface_functions() + output = "" + output += self.generate_parameter_interface_functions() + output += self.generate_grid_interface_functions() return output - def generate_parameter_definitions(self, object): - for name,d in self.parameter_definition.items(): - short=d["short"] - ptype=d["ptype"] - dtype=d["dtype"] - getter="get_"+short - if ptype in ["simple","normal","vector"]: - setter="set_"+short + def generate_parameter_definitions(self, object): + for name, d in self.parameter_definition.items(): + short = d["short"] + ptype = d["ptype"] + dtype = d["dtype"] + getter = "get_" + short + if ptype in ["simple", "normal", "vector"]: + setter = "set_" + short else: - setter=None - range_method="get_"+short+"_range" + setter = None + range_method = "get_" + short + "_range" if ptype in ["simple", "normal", "ro"]: - if dtype!='bool': - object.add_method_parameter( - getter, - setter, - name, - d["description"], - d["default"] - ) - else: - object.add_boolean_parameter( - getter, - setter, - name, - 
d["description"], - d["default"] - ) + if dtype != "bool": + object.add_method_parameter( + getter, setter, name, d["description"], d["default"] + ) + else: + object.add_boolean_parameter( + getter, setter, name, d["description"], d["default"] + ) else: - object.add_array_parameter( - getter, - setter, - range_method, - name, - d["description"] - ) + object.add_array_parameter( + getter, setter, range_method, name, d["description"] + ) + -if __name__=="__main__": - grid_var={ - "pressure" : dict( pyvar=["pressure"], forvar=["pom"], dtype="float64", ndim=1, index_ranges=[(1,10)],vartype="ro"), - "test2" : dict( pyvar=["y"], forvar=["y"], dtype="float64", ndim=2, index_ranges=[(1,"nla"),(1,"nla")]) - } - f=FortranCodeGenerator(grid_variable_definition=grid_var) - print(f.grid_getter_setters()) - print(f.generate_grid_interface_functions()) - +if __name__ == "__main__": + grid_var = { + "pressure": dict( + pyvar=["pressure"], + forvar=["pom"], + dtype="float64", + ndim=1, + index_ranges=[(1, 10)], + vartype="ro", + ), + "test2": dict( + pyvar=["y"], + forvar=["y"], + dtype="float64", + ndim=2, + index_ranges=[(1, "nla"), (1, "nla")], + ), + } + f = FortranCodeGenerator(grid_variable_definition=grid_var) + print(f.grid_getter_setters()) + print(f.generate_grid_interface_functions()) From 58e185b91a363e9b6ee9a48e0efb2118945b5178 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Fri, 11 Oct 2024 15:04:06 +0200 Subject: [PATCH 08/12] Fix an oops --- src/amuse/io/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/amuse/io/base.py b/src/amuse/io/base.py index 53b7bf9cce..ec4ae9f4de 100644 --- a/src/amuse/io/base.py +++ b/src/amuse/io/base.py @@ -503,9 +503,9 @@ def read_fortran_block_float_vectors(self, file, size=3): result = self.read_fortran_block_floats(file) return result.reshape(len(result) // size, size) - def write_fortran_block(self, file, input): + def write_fortran_block(self, file, input_raw): fileformat = self.endianness + "I" - 
input_bytes = bytearray(fileformat) + input_bytes = bytearray(input_raw) length_of_block = len(input_bytes) file.write(struct.pack(fileformat, length_of_block)) file.write(input_bytes) From b870fd813337a434223634ca5baa75cf12a29f92 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Mon, 14 Oct 2024 10:53:24 +0200 Subject: [PATCH 09/12] remove (object) from class --- src/amuse/rfi/async_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/amuse/rfi/async_request.py b/src/amuse/rfi/async_request.py index 16d0711550..4387f753cd 100644 --- a/src/amuse/rfi/async_request.py +++ b/src/amuse/rfi/async_request.py @@ -4,7 +4,7 @@ from . import channel -class AbstractASyncRequest(object): +class AbstractASyncRequest: def __bool__(self): return not self.is_finished From 57115b9f826871a3cca545813a40904d4c17dd9b Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Mon, 14 Oct 2024 10:54:06 +0200 Subject: [PATCH 10/12] revised with Black and manual updates --- src/amuse/rfi/channel.py | 2099 ++++++++++++++++++++++---------------- 1 file changed, 1241 insertions(+), 858 deletions(-) diff --git a/src/amuse/rfi/channel.py b/src/amuse/rfi/channel.py index bac733ea38..6a25c5d767 100644 --- a/src/amuse/rfi/channel.py +++ b/src/amuse/rfi/channel.py @@ -1,9 +1,8 @@ +import sys import inspect -import numpy import os.path -import pickle as pickle +import pickle -import sys import struct import threading import select @@ -14,24 +13,15 @@ import array import logging import shlex +import numpy -logger = logging.getLogger(__name__) - -# -# we want to use the automatic initialization and finalization -# of the MPI library, but sometime MPI should not be imported -# when importing the channel -# so actual import is in function ensure_mpi_initialized -# -MPI = None - from subprocess import Popen, PIPE try: from amuse import config except ImportError: config = None - + from amuse.support.options import OptionalAttributes, option, GlobalOptions from amuse.support.core 
import late from amuse.support import exceptions @@ -43,7 +33,18 @@ from . import async_request -class AbstractMessage(object): +logger = logging.getLogger(__name__) + +# +# we want to use the automatic initialization and finalization +# of the MPI library, but sometime MPI should not be imported +# when importing the channel +# so actual import is in function ensure_mpi_initialized +# +MPI = None + + +class AbstractMessage: def __init__( self, call_id=0, @@ -51,20 +52,21 @@ def __init__( call_count=1, dtype_to_arguments={}, error=False, - big_endian=(sys.byteorder.lower() == 'big'), + big_endian=(sys.byteorder.lower() == "big"), polling_interval=0, - encoded_units = ()): + encoded_units=(), + ): self.polling_interval = polling_interval - + # flags self.big_endian = big_endian self.error = error - + # header self.call_id = call_id self.function_id = function_id self.call_count = call_count - + # data (numpy arrays) self.ints = [] self.longs = [] @@ -74,63 +76,62 @@ def __init__( self.booleans = [] self.pack_data(dtype_to_arguments) - + self.encoded_units = encoded_units - def pack_data(self, dtype_to_arguments): for dtype, attrname in self.dtype_to_message_attribute(): if dtype in dtype_to_arguments: array = pack_array(dtype_to_arguments[dtype], self.call_count, dtype) setattr(self, attrname, array) - + def to_result(self, handle_as_array=False): dtype_to_result = {} for dtype, attrname in self.dtype_to_message_attribute(): result = getattr(self, attrname) if self.call_count > 1 or handle_as_array: - dtype_to_result[dtype] = unpack_array(result , self.call_count, dtype) + dtype_to_result[dtype] = unpack_array(result, self.call_count, dtype) else: dtype_to_result[dtype] = result - + return dtype_to_result - + def dtype_to_message_attribute(self): return ( - ('int32', 'ints'), - ('int64', 'longs'), - ('float32', 'floats'), - ('float64', 'doubles'), - ('bool', 'booleans'), - ('string', 'strings'), + ("int32", "ints"), + ("int64", "longs"), + ("float32", "floats"), + 
("float64", "doubles"), + ("bool", "booleans"), + ("string", "strings"), ) - + def receive(self, comm): raise NotImplementedError - + def send(self, comm): raise NotImplementedError - + def set_error(self, message): self.strings = [message] self.error = True - - + + class MPIMessage(AbstractMessage): def receive(self, comm): header = self.receive_header(comm) self.receive_content(comm, header) - + def receive_header(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") self.mpi_receive(comm, [header, MPI.INT]) return header - + def receive_content(self, comm, header): # 4 flags as 8bit booleans in 1st 4 bytes of header - # endiannes(not supported by MPI channel), error, unused, unused + # endiannes(not supported by MPI channel), error, unused, unused - flags = header.view(dtype='bool_') + flags = header.view(dtype="bool_") self.big_endian = flags[0] self.error = flags[1] self.is_continued = flags[2] @@ -146,117 +147,119 @@ def receive_content(self, comm, header): number_of_booleans = header[8] number_of_strings = header[9] number_of_units = header[10] - + self.ints = self.receive_ints(comm, number_of_ints) self.longs = self.receive_longs(comm, number_of_longs) self.floats = self.receive_floats(comm, number_of_floats) self.doubles = self.receive_doubles(comm, number_of_doubles) self.booleans = self.receive_booleans(comm, number_of_booleans) self.strings = self.receive_strings(comm, number_of_strings) - + self.encoded_units = self.receive_doubles(comm, number_of_units) - def nonblocking_receive(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") request = self.mpi_nonblocking_receive(comm, [header, MPI.INT]) return async_request.ASyncRequest(request, self, comm, header) - + def receive_doubles(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='d') + result = numpy.empty(total, dtype="d") self.mpi_receive(comm, [result, MPI.DOUBLE]) return result else: return [] - + def 
receive_ints(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='i') + result = numpy.empty(total, dtype="i") self.mpi_receive(comm, [result, MPI.INT]) return result else: return [] - + def receive_longs(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='int64') + result = numpy.empty(total, dtype="int64") self.mpi_receive(comm, [result, MPI.INTEGER8]) return result else: return [] - + def receive_floats(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='f') + result = numpy.empty(total, dtype="f") self.mpi_receive(comm, [result, MPI.FLOAT]) return result else: return [] - - + def receive_booleans(self, comm, total): if total > 0: - result = numpy.empty(total, dtype='b') - self.mpi_receive(comm, [result, MPI.C_BOOL or MPI.BYTE]) # if C_BOOL null datatype (ie undefined) fallback + result = numpy.empty(total, dtype="b") + self.mpi_receive( + comm, [result, MPI.C_BOOL or MPI.BYTE] + ) # if C_BOOL null datatype (ie undefined) fallback return numpy.logical_not(result == 0) else: return [] - - + def receive_strings(self, comm, total): if total > 0: - sizes = numpy.empty(total, dtype='i') - + sizes = numpy.empty(total, dtype="i") + self.mpi_receive(comm, [sizes, MPI.INT]) - + logger.debug("got %d strings of size %s", total, sizes) - + byte_size = 0 for size in sizes: byte_size = byte_size + size + 1 - + data_bytes = numpy.empty(byte_size, dtype=numpy.uint8) self.mpi_receive(comm, [data_bytes, MPI.CHARACTER]) - + strings = [] begin = 0 for size in sizes: - strings.append(data_bytes[begin:begin + size].tobytes().decode('latin_1')) + strings.append( + data_bytes[begin : begin + size].tobytes().decode("latin_1") + ) begin = begin + size + 1 - + logger.debug("got %d strings of size %s, data = %s", total, sizes, strings) return numpy.array(strings) else: return [] - - + def send(self, comm): - header = numpy.array([ - 0, - self.call_id, - self.function_id, - self.call_count, - len(self.ints) , - len(self.longs) , - 
len(self.floats) , - len(self.doubles) , - len(self.booleans) , - len(self.strings) , - len(self.encoded_units) - ], dtype='i') - - - flags = header.view(dtype='bool_') + header = numpy.array( + [ + 0, + self.call_id, + self.function_id, + self.call_count, + len(self.ints), + len(self.longs), + len(self.floats), + len(self.doubles), + len(self.booleans), + len(self.strings), + len(self.encoded_units), + ], + dtype="i", + ) + + flags = header.view(dtype="bool_") flags[0] = self.big_endian flags[1] = self.error flags[2] = len(self.encoded_units) > 0 self.send_header(comm, header) self.send_content(comm) - + def send_header(self, comm, header): self.mpi_send(comm, [header, MPI.INT]) - + def send_content(self, comm): self.send_ints(comm, self.ints) self.send_longs(comm, self.longs) @@ -265,100 +268,101 @@ def send_content(self, comm): self.send_booleans(comm, self.booleans) self.send_strings(comm, self.strings) self.send_doubles(comm, self.encoded_units) - def send_ints(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='int32') + sendbuffer = numpy.array(array, dtype="int32") self.mpi_send(comm, [sendbuffer, MPI.INT]) - + def send_longs(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='int64') - self.mpi_send(comm, [sendbuffer, MPI.INTEGER8]) - + sendbuffer = numpy.array(array, dtype="int64") + self.mpi_send(comm, [sendbuffer, MPI.INTEGER8]) + def send_doubles(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='d') + sendbuffer = numpy.array(array, dtype="d") self.mpi_send(comm, [sendbuffer, MPI.DOUBLE]) - + def send_floats(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='f') + sendbuffer = numpy.array(array, dtype="f") self.mpi_send(comm, [sendbuffer, MPI.FLOAT]) - + def send_strings(self, comm, array): if len(array) == 0: return - - lengths = numpy.array( [len(s) for s in array] ,dtype='i') - - chars=(chr(0).join(array)+chr(0)).encode("utf-8") - chars 
= numpy.frombuffer(chars, dtype='uint8') - if len(chars) != lengths.sum()+len(lengths): - raise Exception("send_strings size mismatch {0} vs {1}".format( len(chars) , lengths.sum()+len(lengths) )) + lengths = numpy.array([len(s) for s in array], dtype="i") + + chars = (chr(0).join(array) + chr(0)).encode("utf-8") + chars = numpy.frombuffer(chars, dtype="uint8") + + if len(chars) != lengths.sum() + len(lengths): + raise Exception( + "send_strings size mismatch {0} vs {1}".format( + len(chars), lengths.sum() + len(lengths) + ) + ) self.mpi_send(comm, [lengths, MPI.INT]) self.mpi_send(comm, [chars, MPI.CHARACTER]) - + def send_booleans(self, comm, array): if len(array) > 0: - sendbuffer = numpy.array(array, dtype='b') + sendbuffer = numpy.array(array, dtype="b") self.mpi_send(comm, [sendbuffer, MPI.C_BOOL or MPI.BYTE]) def set_error(self, message): self.strings = [message] self.error = True - + def mpi_nonblocking_receive(self, comm, array): raise NotImplementedError() - + def mpi_receive(self, comm, array): raise NotImplementedError() - + def mpi_send(self, comm, array): raise NotImplementedError() - - + + class ServerSideMPIMessage(MPIMessage): def mpi_receive(self, comm, array): request = comm.Irecv(array, source=0, tag=999) request.Wait() - + def mpi_send(self, comm, array): comm.Bcast(array, root=MPI.ROOT) - + def send_header(self, comm, array): requests = [] for rank in range(comm.Get_remote_size()): request = comm.Isend(array, dest=rank, tag=989) requests.append(request) MPI.Request.Waitall(requests) - - + def mpi_nonblocking_receive(self, comm, array): return comm.Irecv(array, source=0, tag=999) def receive_header(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") request = self.mpi_nonblocking_receive(comm, [header, MPI.INT]) if self.polling_interval > 0: is_finished = request.Test() while not is_finished: - time.sleep(self.polling_interval / 1000000.) 
+ time.sleep(self.polling_interval / 1000000.0) is_finished = request.Test() request.Wait() else: request.Wait() return header - - + class ClientSideMPIMessage(MPIMessage): def mpi_receive(self, comm, array): comm.Bcast(array, root=0) - + def mpi_send(self, comm, array): comm.Send(array, dest=0, tag=999) @@ -366,22 +370,24 @@ def mpi_nonblocking_receive(self, comm, array): return comm.Irecv(array, source=0, tag=999) def receive_header(self, comm): - header = numpy.zeros(11, dtype='i') + header = numpy.zeros(11, dtype="i") request = comm.Irecv([header, MPI.INT], source=0, tag=989) if self.polling_interval > 0: is_finished = request.Test() while not is_finished: - time.sleep(self.polling_interval / 1000000.) + time.sleep(self.polling_interval / 1000000.0) is_finished = request.Test() request.Wait() else: request.Wait() return header + MAPPING = {} + def pack_array(array, length, dtype): - if dtype == 'string': + if dtype == "string": if length == 1 and len(array) > 0 and isinstance(array[0], str): return array result = [] @@ -401,107 +407,137 @@ def pack_array(array, length, dtype): result = MAPPING.dtype if len(result) != total_length: result = numpy.empty(length * len(array), dtype=dtype) - else: + else: result = numpy.empty(length * len(array), dtype=dtype) - + for i in range(len(array)): offset = i * length - result[offset:offset + length] = array[i] + result[offset : offset + length] = array[i] return result - + def unpack_array(array, length, dtype=None): result = [] total = len(array) // length for i in range(total): offset = i * length - result.append(array[offset:offset + length]) + result.append(array[offset : offset + length]) return result + class AbstractMessageChannel(OptionalAttributes): """ Abstract base class of all message channel. - + A message channel is used to send and retrieve messages from a remote party. A message channel can also setup the remote party. For example starting an instance of an application using MPI calls. 
- + The messages are encoded as arguments to the send and retrieve methods. Each message has an id and and optional list of doubles, integers, floats and/or strings. - + """ - + def __init__(self, **options): OptionalAttributes.__init__(self, **options) - + @classmethod - def GDB(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e', 'gdb'] - + def GDB( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-hold", "-display", os.environ["DISPLAY"], "-e", "gdb"] + if immediate_run: - arguments.extend([ '-ex', 'run']) - - arguments.extend(['--args']) - + arguments.extend(["-ex", "run"]) + + arguments.extend(["--args"]) + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - - command = 'xterm' + + command = "xterm" return command, arguments @classmethod - def LLDB(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e', 'lldb', '--'] + def LLDB( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-hold", "-display", os.environ["DISPLAY"], "-e", "lldb", "--"] if not interpreter_executable is None: arguments.append(interpreter_executable) arguments.append(full_name_of_the_worker) - command = 'xterm' + command = "xterm" return command, arguments @classmethod - def DDD(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - if os.name == 'nt': - arguments = [full_name_of_the_worker, "--args",full_name_of_the_worker] + def DDD( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + if os.name == "nt": + arguments = [full_name_of_the_worker, "--args", full_name_of_the_worker] command = 
channel.adg_exe return command, arguments else: - arguments = ['-display', os.environ['DISPLAY'], '-e', 'ddd', '--args'] - + arguments = ["-display", os.environ["DISPLAY"], "-e", "ddd", "--args"] + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - - command = 'xterm' + + command = "xterm" return command, arguments - + @classmethod - def VALGRIND(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def VALGRIND( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): # arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e', 'valgrind', full_name_of_the_worker] arguments = [] - + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - command = 'valgrind' + command = "valgrind" return command, arguments - - + @classmethod - def XTERM(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-hold', '-display', os.environ['DISPLAY'], '-e'] - + def XTERM( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-hold", "-display", os.environ["DISPLAY"], "-e"] + if not interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) command = "xterm" @@ -527,50 +563,72 @@ def REDIRECT( if command is None: command = sys.executable - + return command, arguments - + @classmethod - def GDBR(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def GDBR( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): "remote gdb, can run without xterm" - - arguments = ['localhost:{0}'.format(channel.debugger_port)] - + + arguments = ["localhost:{0}".format(channel.debugger_port)] + if not 
interpreter_executable is None: arguments.append(interpreter_executable) - + arguments.append(full_name_of_the_worker) - + command = channel.gdbserver_exe return command, arguments - + @classmethod - def NODEBUGGER(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def NODEBUGGER( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): if not interpreter_executable is None: return interpreter_executable, [full_name_of_the_worker] else: return full_name_of_the_worker, [] - - + @classmethod - def STRACE(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): - arguments = ['-ostrace-out', '-ff'] + def STRACE( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): + arguments = ["-ostrace-out", "-ff"] if not interpreter_executable is None: arguments.append(interpreter_executable) arguments.append(full_name_of_the_worker) - command = 'strace' + command = "strace" return command, arguments - + @classmethod - def CUSTOM(cls, full_name_of_the_worker, channel, interpreter_executable=None, immediate_run=True): + def CUSTOM( + cls, + full_name_of_the_worker, + channel, + interpreter_executable=None, + immediate_run=True, + ): arguments = list(shlex.split(channel.custom_args)) if not interpreter_executable is None: arguments.append(interpreter_executable) arguments.append(full_name_of_the_worker) command = channel.custom_exe return command, arguments - - + @classmethod def is_multithreading_supported(cls): return True @@ -579,110 +637,131 @@ def is_multithreading_supported(cls): def initialize_mpi(self): """Is MPI initialized in the code or not. 
Defaults to True if MPI is available""" return config.mpi.is_enabled - - @option(type='string', sections=("channel",)) + + @option(type="string", sections=("channel",)) def worker_code_suffix(self): - return '' - - @option(type='string', sections=("channel",)) + return "" + + @option(type="string", sections=("channel",)) def worker_code_prefix(self): - return '' - - @option(type='string', sections=("channel",)) + return "" + + @option(type="string", sections=("channel",)) def worker_code_directory(self): - return '' + return "" @option(type="boolean", sections=("channel",)) def can_redirect_output(self): return True - + @option(sections=("channel",)) def python_exe_for_redirection(self): return None - - + @option(type="int", sections=("channel",)) def debugger_port(self): return 4343 - + @option(type="string", sections=("channel",)) def gdbserver_exe(self): - return 'gdbserver' - + return "gdbserver" + @option(type="string", sections=("channel",)) def adg_exe(self): - return 'adg.exe' - + return "adg.exe" + @option(type="string", sections=("channel",)) def custom_exe(self): - return 'mintty.exe' - + return "mintty.exe" + @option(type="string", sections=("channel",)) def custom_args(self): - return '--hold -e gdb --args' + return "--hold -e gdb --args" - @option(type='boolean', sections=("channel",)) + @option(type="boolean", sections=("channel",)) def debugger_immediate_run(self): return True - - @option(type='boolean', sections=("channel",)) + + @option(type="boolean", sections=("channel",)) def must_check_if_worker_is_up_to_date(self): return True - @option(type='boolean', sections=("channel",)) + @option(type="boolean", sections=("channel",)) def check_worker_location(self): return True - + @option(type="int", sections=("channel",)) def number_of_workers(self): return 1 - + def get_amuse_root_directory(self): return self.amuse_root_dir - - @option(type="string", sections=('data',)) - def amuse_root_dir(self): # needed for location of data, so same as in 
support.__init__ + + @option(type="string", sections=("data",)) + def amuse_root_dir( + self, + ): # needed for location of data, so same as in support.__init__ return get_amuse_root_dir() - + def check_if_worker_is_up_to_date(self, object): if not self.must_check_if_worker_is_up_to_date: return - + name_of_the_compiled_file = self.full_name_of_the_worker modificationtime_of_worker = os.stat(name_of_the_compiled_file).st_mtime my_class = type(object) for x in dir(my_class): - if x.startswith('__'): + if x.startswith("__"): continue value = getattr(my_class, x) - if hasattr(value, 'crc32'): - is_up_to_date = value.is_compiled_file_up_to_date(modificationtime_of_worker) + if hasattr(value, "crc32"): + is_up_to_date = value.is_compiled_file_up_to_date( + modificationtime_of_worker + ) if not is_up_to_date: - raise exceptions.CodeException("""The worker code of the '{0}' interface class is not up to date. + raise exceptions.CodeException( + """The worker code of the '{0}' interface class is not up to date. Please do a 'make clean; make' in the root directory. 
-""".format(type(object).__name__)) +""".format( + type(object).__name__ + ) + ) def get_full_name_of_the_worker(self, type): if os.path.isabs(self.name_of_the_worker): - full_name_of_the_worker=self.name_of_the_worker - + full_name_of_the_worker = self.name_of_the_worker + if not self.check_worker_location: return full_name_of_the_worker - + if not os.path.exists(full_name_of_the_worker): - raise exceptions.CodeException("The worker path has been specified, but it is not found: \n{0}".format(full_name_of_the_worker)) + raise exceptions.CodeException( + "The worker path has been specified, but it is not found: \n{0}".format( + full_name_of_the_worker + ) + ) if not os.access(full_name_of_the_worker, os.X_OK): - raise exceptions.CodeException("The worker application exists, but it is not executable.\n{0}".format(full_name_of_the_worker)) - + raise exceptions.CodeException( + "The worker application exists, but it is not executable.\n{0}".format( + full_name_of_the_worker + ) + ) + return full_name_of_the_worker - - exe_name = self.worker_code_prefix + self.name_of_the_worker + self.worker_code_suffix + + exe_name = ( + self.worker_code_prefix + self.name_of_the_worker + self.worker_code_suffix + ) if not self.check_worker_location: if len(self.worker_code_directory) > 0: - full_name_of_the_worker = os.path.join(self.worker_code_directory, exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.join( + self.worker_code_directory, exe_name + ) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) return full_name_of_the_worker else: raise Exception("Must provide a worker_code_directory") @@ -690,22 +769,32 @@ def get_full_name_of_the_worker(self, type): tried_workers = [] directory = os.path.dirname(inspect.getfile(type)) - full_name_of_the_worker = os.path.join(directory, '..','..','_workers', exe_name) - full_name_of_the_worker = 
os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.join( + directory, "..", "..", "_workers", exe_name + ) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) - + if len(self.worker_code_directory) > 0: full_name_of_the_worker = os.path.join(self.worker_code_directory, exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) - + directory_of_this_module = os.path.dirname(os.path.dirname(__file__)) - full_name_of_the_worker = os.path.join(directory_of_this_module, '_workers', exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.join( + directory_of_this_module, "_workers", exe_name + ) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) @@ -714,40 +803,51 @@ def get_full_name_of_the_worker(self, type): while not current_type.__bases__[0] is object: directory_of_this_module = os.path.dirname(inspect.getfile(current_type)) full_name_of_the_worker = os.path.join(directory_of_this_module, exe_name) - full_name_of_the_worker = os.path.normpath(os.path.abspath(full_name_of_the_worker)) + full_name_of_the_worker = os.path.normpath( + os.path.abspath(full_name_of_the_worker) + ) if os.path.exists(full_name_of_the_worker): return full_name_of_the_worker tried_workers.append(full_name_of_the_worker) current_type = current_type.__bases__[0] - raise exceptions.CodeException("The worker 
application does not exist, it should be at: \n{0}".format('\n'.join(tried_workers))) - - def send_message(self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units = None): + raise exceptions.CodeException( + "The worker application does not exist, it should be at: \n{0}".format( + "\n".join(tried_workers) + ) + ) + + def send_message( + self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units=None + ): pass - - def recv_message(self, call_id=0, function_id=-1, handle_as_array=False, has_units = False): + + def recv_message( + self, call_id=0, function_id=-1, handle_as_array=False, has_units=False + ): pass - - def nonblocking_recv_message(self, call_id=0, function_id=-1, handle_as_array=False): + + def nonblocking_recv_message( + self, call_id=0, function_id=-1, handle_as_array=False + ): pass - + def start(self): pass - + def stop(self): pass def is_active(self): return True - + @classmethod def is_root(self): return True - + def is_polling_supported(self): return False - - + def determine_length_from_data(self, dtype_to_arguments): def get_length(type_and_values): argument_type, argument_values = type_and_values @@ -760,13 +860,11 @@ def get_length(type_and_values): except: result = max(result, 1) return result - - - + lengths = [get_length(x) for x in dtype_to_arguments.items()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) def split_message( @@ -774,58 +872,68 @@ def split_message( ): if call_count <= 1: raise Exception("split message called with call_count<=1") - + dtype_to_result = {} - - ndone=0 - while ndone>> is_mpd_running() True - - + + """ if not MpiChannel.is_supported(): return True - + MpiChannel.ensure_mpi_initialized() - + name_of_the_vendor, version = MPI.get_vendor() - if name_of_the_vendor == 'MPICH2': + if name_of_the_vendor == "MPICH2": must_check_mpd = True - if 'AMUSE_MPD_CHECK' in os.environ: - must_check_mpd = os.environ['AMUSE_MPD_CHECK'] == '1' - if 'PMI_PORT' in os.environ: + if 
"AMUSE_MPD_CHECK" in os.environ: + must_check_mpd = os.environ["AMUSE_MPD_CHECK"] == "1" + if "PMI_PORT" in os.environ: must_check_mpd = False - if 'PMI_RANK' in os.environ: + if "PMI_RANK" in os.environ: must_check_mpd = False - if 'HYDRA_CONTROL_FD' in os.environ: + if "HYDRA_CONTROL_FD" in os.environ: must_check_mpd = False - + if not must_check_mpd: return True try: - process = Popen(['mpdtrace'], stdout=PIPE, stderr=PIPE) + process = Popen(["mpdtrace"], stdout=PIPE, stderr=PIPE) (output_string, error_string) = process.communicate() return not (process.returncode == 255) except OSError as ex: @@ -886,13 +994,14 @@ def is_mpd_running(): class MpiChannel(AbstractMessageChannel): """ Message channel based on MPI calls to send and recv the messages - + :argument name_of_the_worker: Name of the application to start :argument number_of_workers: Number of parallel processes :argument legacy_interface_type: Type of the legacy interface :argument debug_with_gdb: If True opens an xterm with a gdb to debug the remote process :argument hostname: Name of the node to run the application on """ + _mpi_is_broken_after_possible_code_crash = False _intercomms_to_disconnect = [] _is_registered = False @@ -900,68 +1009,81 @@ class MpiChannel(AbstractMessageChannel): _scheduler_index = 0 _scheduler_initialized = False - - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, **options): + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + **options, + ): AbstractMessageChannel.__init__(self, **options) - + self.inuse_semaphore = threading.Semaphore() # logging.basicConfig(level=logging.WARN) # logger.setLevel(logging.DEBUG) # logging.getLogger("code").setLevel(logging.DEBUG) - + self.ensure_mpi_initialized() - + self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + if not legacy_interface_type is None: - self.full_name_of_the_worker = 
self.get_full_name_of_the_worker(legacy_interface_type) + self.full_name_of_the_worker = self.get_full_name_of_the_worker( + legacy_interface_type + ) else: self.full_name_of_the_worker = self.name_of_the_worker - + if self.check_mpi: if not is_mpd_running(): - raise exceptions.CodeException("The mpd daemon is not running, please make sure it is started before starting this code") - + raise exceptions.CodeException( + "The mpd daemon is not running, please make sure it is started before starting this code" + ) + if self._mpi_is_broken_after_possible_code_crash: - raise exceptions.CodeException("Another code has crashed, cannot spawn a new code, please stop the script and retry") + raise exceptions.CodeException( + "Another code has crashed, cannot spawn a new code, please stop the script and retry" + ) if not self.hostname is None: self.info = MPI.Info.Create() - self.info['host'] = self.hostname + self.info["host"] = self.hostname else: if self.job_scheduler: - self.info = self.get_info_from_job_scheduler(self.job_scheduler, self.number_of_workers) + self.info = self.get_info_from_job_scheduler( + self.job_scheduler, self.number_of_workers + ) else: self.info = MPI.Info.Create() - - for key,value in self.mpi_info_options.items(): - self.info[key]=value - + + for key, value in self.mpi_info_options.items(): + self.info[key] = value + self.cached = None self.intercomm = None self._is_inuse = False self._communicated_splitted_message = False logger.debug("MPI channel created with info items: %s", str(self.info.items())) - @classmethod def ensure_mpi_initialized(cls): global MPI - + if MPI is None: import mpi4py.MPI + MPI = mpi4py.MPI cls.register_finalize_code() @classmethod def is_threaded(cls): - #We want this for backwards compatibility with mpi4py versions < 2.0.0 - #currently unused after Init/Init_threaded was removed from - #this module. 
+ # We want this for backwards compatibility with mpi4py versions < 2.0.0 + # currently unused after Init/Init_threaded was removed from + # this module. from mpi4py import rc + try: return rc.threaded except AttributeError: @@ -972,7 +1094,7 @@ def register_finalize_code(cls): if not cls._is_registered: atexit.register(cls.finialize_mpi_atexit) cls._is_registered = True - + @classmethod def finialize_mpi_atexit(cls): if not MPI.Is_initialized(): @@ -982,26 +1104,26 @@ def finialize_mpi_atexit(cls): try: for x in cls._intercomms_to_disconnect: x.Disconnect() - + except MPI.Exception as ex: return - + @classmethod def is_multithreading_supported(cls): return MPI.Query_thread() == MPI.THREAD_MULTIPLE - + @option(type="boolean", sections=("channel",)) def check_mpi(self): return True - + @option(type="boolean", sections=("channel",)) def debug_with_gdb(self): return False - + @option(sections=("channel",)) def hostname(self): return None - + @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): """Name of the debugger to use when starting the code""" @@ -1010,74 +1132,76 @@ def debugger(self): @option(type="dict", sections=("channel",)) def mpi_info_options(self): return dict() - + @option(type="int", sections=("channel",)) def max_message_length(self): """ For calls to functions that can handle arrays, MPI messages may get too long for large N. The MPI channel will split long messages into blocks of size max_message_length. 
- """ + """ return 1000000 - @late def redirect_stdout_file(self): return "/dev/null" - + @late def redirect_stderr_file(self): return "/dev/null" - + @late def debugger_method(self): return self.DEBUGGERS[self.debugger] - + @classmethod def is_supported(cls): - if hasattr(config, 'mpi') and hasattr(config.mpi, 'is_enabled'): + if hasattr(config, "mpi") and hasattr(config.mpi, "is_enabled"): if not config.mpi.is_enabled: return False try: from mpi4py import MPI + return True except ImportError: return False - @option(type="boolean", sections=("channel",)) def can_redirect_output(self): name_of_the_vendor, version = MPI.get_vendor() - if name_of_the_vendor == 'MPICH2': - if 'MPISPAWN_ARGV_0' in os.environ: + if name_of_the_vendor == "MPICH2": + if "MPISPAWN_ARGV_0" in os.environ: return False return True - - + @option(type="boolean", sections=("channel",)) def must_disconnect_on_stop(self): name_of_the_vendor, version = MPI.get_vendor() - if name_of_the_vendor == 'MPICH2': - if 'MPISPAWN_ARGV_0' in os.environ: + if name_of_the_vendor == "MPICH2": + if "MPISPAWN_ARGV_0" in os.environ: return False return True - + @option(type="int", sections=("channel",)) def polling_interval_in_milliseconds(self): return 0 - + @classmethod def is_root(cls): cls.ensure_mpi_initialized() return MPI.COMM_WORLD.rank == 0 - + def start(self): logger.debug("starting mpi worker process") logger.debug("mpi_enabled: %s", str(self.initialize_mpi)) - + if not self.debugger_method is None: - command, arguments = self.debugger_method(self.full_name_of_the_worker, self, - interpreter_executable=self.interpreter_executable, immediate_run=self.debugger_immediate_run) + command, arguments = self.debugger_method( + self.full_name_of_the_worker, + self, + interpreter_executable=self.interpreter_executable, + immediate_run=self.debugger_immediate_run, + ) else: if not self.can_redirect_output or ( self.redirect_stdout_file == "none" @@ -1090,16 +1214,28 @@ def start(self): command = 
self.interpreter_executable arguments = [self.full_name_of_the_worker] else: - command, arguments = self.REDIRECT(self.full_name_of_the_worker, self.redirect_stdout_file, self.redirect_stderr_file, command=self.python_exe_for_redirection, interpreter_executable=self.interpreter_executable) + command, arguments = self.REDIRECT( + self.full_name_of_the_worker, + self.redirect_stdout_file, + self.redirect_stderr_file, + command=self.python_exe_for_redirection, + interpreter_executable=self.interpreter_executable, + ) - logger.debug("spawning %d mpi processes with command `%s`, arguments `%s` and environment '%s'", self.number_of_workers, command, arguments, os.environ) + logger.debug( + "spawning %d mpi processes with command `%s`, arguments `%s` and environment '%s'", + self.number_of_workers, + command, + arguments, + os.environ, + ) - self.intercomm = MPI.COMM_SELF.Spawn(command, arguments, self.number_of_workers, info=self.info) + self.intercomm = MPI.COMM_SELF.Spawn( + command, arguments, self.number_of_workers, info=self.info + ) logger.debug("worker spawn done") - - - + def stop(self): if not self.intercomm is None: try: @@ -1110,9 +1246,9 @@ def stop(self): except MPI.Exception as ex: if ex.error_class == MPI.ERR_OTHER: type(self)._mpi_is_broken_after_possible_code_crash = True - + self.intercomm = None - + def determine_length_from_datax(self, dtype_to_arguments): def get_length(x): if x: @@ -1122,41 +1258,48 @@ def get_length(x): except: return 1 return 1 - - - + lengths = [get_length(x) for x in dtype_to_arguments.values()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) - - + def send_message( self, call_id, function_id, dtype_to_arguments={}, encoded_units=() ): if self.intercomm is None: - raise exceptions.CodeException("You've tried to send a message to a code that is not running") - + raise exceptions.CodeException( + "You've tried to send a message to a code that is not running" + ) + call_count = 
self.determine_length_from_data(dtype_to_arguments) - + if call_count > self.max_message_length: - self.split_message(call_id, function_id, call_count, dtype_to_arguments, encoded_units) + self.split_message( + call_id, function_id, call_count, dtype_to_arguments, encoded_units + ) else: if self.is_inuse(): - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) self.inuse_semaphore.acquire() try: if self._is_inuse: - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) self._is_inuse = True finally: self.inuse_semaphore.release() message = ServerSideMPIMessage( - call_id, function_id, - call_count, dtype_to_arguments, - encoded_units = encoded_units + call_id, + function_id, + call_count, + dtype_to_arguments, + encoded_units=encoded_units, ) message.send(self.intercomm) @@ -1166,7 +1309,7 @@ def recv_message(self, call_id, function_id, handle_as_array, has_units=False): self._communicated_splitted_message = False del self._merged_results_splitted_message return x - + message = ServerSideMPIMessage( polling_interval=self.polling_interval_in_milliseconds * 1000 ) @@ -1179,74 +1322,103 @@ def recv_message(self, call_id, function_id, handle_as_array, has_units=False): self.inuse_semaphore.acquire() try: if not self._is_inuse: - raise exceptions.CodeException("You've tried to recv a message to a code that is not handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to recv a message to a code that is not handling a message, this is not correct" + ) self._is_inuse = False finally: 
self.inuse_semaphore.release() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] if len(message.strings) > 0 else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." + self.stop() + error_message += " - code probably died, sorry." raise exceptions.CodeException("Error in code: " + error_message) if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) - - def nonblocking_recv_message(self, call_id, function_id, handle_as_array, has_units = False): + def nonblocking_recv_message( + self, call_id, function_id, handle_as_array, has_units=False + ): request = ServerSideMPIMessage().nonblocking_receive(self.intercomm) + def handle_result(function): self._is_inuse = False - + message = function() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] + if len(message.strings) > 0 + else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." 
- raise exceptions.CodeException("Error in (asynchronous) communication with worker: " + error_message) - + self.stop() + error_message += " - code probably died, sorry." + raise exceptions.CodeException( + "Error in (asynchronous) communication with worker: " + + error_message + ) + if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) - + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) + if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) request.add_result_handler(handle_result) - + return request - + def is_active(self): return self.intercomm is not None - + def is_inuse(self): return self._is_inuse - + def is_polling_supported(self): return True - + def __getstate__(self): - return {'state':'empty'} - + return {"state": "empty"} + def __setstate__(self, state): self.info = MPI.INFO_NULL self.cached = None @@ -1260,19 +1432,23 @@ def job_scheduler(self): """Name of the job scheduler to use when starting the code, if given will use job scheduler to find list of hostnames for spawning""" return "" - def get_info_from_job_scheduler(self, name, number_of_workers = 1): + def get_info_from_job_scheduler(self, name, number_of_workers=1): if name == "slurm": return self.get_info_from_slurm(number_of_workers) return MPI.INFO_NULL @classmethod def get_info_from_slurm(cls, number_of_workers): - has_slurm_env_variables = 'SLURM_NODELIST' in os.environ and 'SLURM_TASKS_PER_NODE' in os.environ + 
has_slurm_env_variables = ( + "SLURM_NODELIST" in os.environ and "SLURM_TASKS_PER_NODE" in os.environ + ) if not has_slurm_env_variables: return MPI.INFO_NULL if not cls._scheduler_initialized: - nodelist = slurm.parse_slurm_nodelist(os.environ['SLURM_NODELIST']) - tasks_per_node = slurm.parse_slurm_tasks_per_node(os.environ['SLURM_TASKS_PER_NODE']) + nodelist = slurm.parse_slurm_nodelist(os.environ["SLURM_NODELIST"]) + tasks_per_node = slurm.parse_slurm_tasks_per_node( + os.environ["SLURM_TASKS_PER_NODE"] + ) all_nodes = [] for node, tasks in zip(nodelist, tasks_per_node): for _ in range(tasks): @@ -1284,13 +1460,13 @@ def get_info_from_slurm(cls, number_of_workers): hostnames = [] count = 0 while count < number_of_workers: - hostnames.append(cls._scheduler_nodes[cls._scheduler_index]) - count += 1 - cls._scheduler_index += 1 - if cls._scheduler_index >= len(cls._scheduler_nodes): - cls._scheduler_index = 0 - host = ','.join(hostnames) - print("HOST:", host, cls._scheduler_index, os.environ['SLURM_TASKS_PER_NODE']) + hostnames.append(cls._scheduler_nodes[cls._scheduler_index]) + count += 1 + cls._scheduler_index += 1 + if cls._scheduler_index >= len(cls._scheduler_nodes): + cls._scheduler_index = 0 + host = ",".join(hostnames) + print("HOST:", host, cls._scheduler_index, os.environ["SLURM_TASKS_PER_NODE"]) info = MPI.Info.Create() # actually in mpich and openmpi, the host parameter is interpreted as a @@ -1299,17 +1475,16 @@ def get_info_from_slurm(cls, number_of_workers): return info - class MultiprocessingMPIChannel(AbstractMessageChannel): """ - Message channel based on JSON messages. - + Message channel based on JSON messages. + The remote party functions as a message forwarder. Each message is forwarded to a real application using MPI. This is message channel is a lot slower than the MPI message channel. But, it is useful during testing with the MPICH2 nemesis channel. 
As the tests will run as one - application on one node they will cause oversaturation + application on one node they will cause oversaturation of the processor(s) on the node. Each legacy code will call the MPI_FINALIZE call and this call will wait for the MPI_FINALIZE call of the main test process. During @@ -1318,52 +1493,63 @@ class MultiprocessingMPIChannel(AbstractMessageChannel): instead of the normal MPIChannel. Then, part of the test is performed in a separate application (at least as MPI sees it) and this part can be stopped after each - sub-test, thus removing unneeded applications. + sub-test, thus removing unneeded applications. """ - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, **options): + + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + **options, + ): AbstractMessageChannel.__init__(self, **options) - + self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + if not legacy_interface_type is None: - self.full_name_of_the_worker = self.get_full_name_of_the_worker(legacy_interface_type) + self.full_name_of_the_worker = self.get_full_name_of_the_worker( + legacy_interface_type + ) else: self.full_name_of_the_worker = self.name_of_the_worker - + self.process = None - + @option(type="boolean") def debug_with_gdb(self): return False - + @option def hostname(self): return None - + def start(self): - name_of_dir = "/tmp/amuse_" + os.getenv('USER') - self.name_of_the_socket, self.server_socket = self._createAServerUNIXSocket(name_of_dir) + name_of_dir = "/tmp/amuse_" + os.getenv("USER") + self.name_of_the_socket, self.server_socket = self._createAServerUNIXSocket( + name_of_dir + ) environment = os.environ.copy() - - if 'PYTHONPATH' in environment: - environment['PYTHONPATH'] = environment['PYTHONPATH'] + ':' + self._extra_path_item(__file__) + + if "PYTHONPATH" in environment: + environment["PYTHONPATH"] 
= ( + environment["PYTHONPATH"] + ":" + self._extra_path_item(__file__) + ) else: - environment['PYTHONPATH'] = self._extra_path_item(__file__) - - + environment["PYTHONPATH"] = self._extra_path_item(__file__) + all_options = {} for x in self.iter_options(): all_options[x.name] = getattr(self, x.name) - - + template = """from {3} import {4} o = {1!r} m = channel.MultiprocessingMPIChannel('{0}',**o) m.run_mpi_channel('{2}')""" modulename = type(self).__module__ - packagagename, thismodulename = modulename.rsplit('.', 1) - + packagagename, thismodulename = modulename.rsplit(".", 1) + code_string = template.format( self.full_name_of_the_worker, all_options, @@ -1373,19 +1559,25 @@ def start(self): ) self.process = Popen([sys.executable, "-c", code_string], env=environment) self.client_socket, undef = self.server_socket.accept() - + def is_active(self): return self.process is not None - + def stop(self): - self._send(self.client_socket, ('stop', (),)) - result = self._recv(self.client_socket) + self._send( + self.client_socket, + ( + "stop", + (), + ), + ) + result = self._recv(self.client_socket) self.process.wait() self.client_socket.close() self.server_socket.close() self._remove_socket(self.name_of_the_socket) self.process = None - + def run_mpi_channel(self, name_of_the_socket): channel = MpiChannel(self.full_name_of_the_worker, **self._local_options) channel.start() @@ -1395,39 +1587,55 @@ def run_mpi_channel(self, name_of_the_socket): while is_running: message, args = self._recv(socket) result = None - if message == 'stop': + if message == "stop": channel.stop() is_running = False - if message == 'send_message': + if message == "send_message": result = channel.send_message(*args) - if message == 'recv_message': + if message == "recv_message": result = channel.recv_message(*args) self._send(socket, result) finally: socket.close() - - def send_message(self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units = ()): - self._send(self.client_socket, 
('send_message', (call_id, function_id, dtype_to_arguments),)) + + def send_message( + self, call_id=0, function_id=-1, dtype_to_arguments={}, encoded_units=() + ): + self._send( + self.client_socket, + ( + "send_message", + (call_id, function_id, dtype_to_arguments), + ), + ) result = self._recv(self.client_socket) return result - def recv_message(self, call_id=0, function_id=-1, handle_as_array=False, has_units=False): - self._send(self.client_socket, ('recv_message', (call_id, function_id, handle_as_array),)) - result = self._recv(self.client_socket) + def recv_message( + self, call_id=0, function_id=-1, handle_as_array=False, has_units=False + ): + self._send( + self.client_socket, + ( + "recv_message", + (call_id, function_id, handle_as_array), + ), + ) + result = self._recv(self.client_socket) return result - + def _send(self, client_socket, message): message_string = pickle.dumps(message) header = struct.pack("i", len(message_string)) client_socket.sendall(header) client_socket.sendall(message_string) - + def _recv(self, client_socket): header = self._receive_all(client_socket, 4) length = struct.unpack("i", header) message_string = self._receive_all(client_socket, length[0]) return pickle.loads(message_string) - + def _receive_all(self, client_socket, number_of_bytes): block_size = 4096 bytes_left = number_of_bytes @@ -1439,18 +1647,17 @@ def _receive_all(self, client_socket, number_of_bytes): blocks.append(block) bytes_left -= len(block) return bytearray().join(blocks) - - + def _createAServerUNIXSocket(self, name_of_the_directory, name_of_the_socket=None): import uuid import socket - + if name_of_the_socket == None: name_of_the_socket = os.path.join(name_of_the_directory, str(uuid.uuid1())) - + if not os.path.exists(name_of_the_directory): os.makedirs(name_of_the_directory) - + server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self._remove_socket(name_of_the_socket) server_socket.bind(name_of_the_socket) @@ -1459,19 +1666,20 @@ def 
_createAServerUNIXSocket(self, name_of_the_directory, name_of_the_socket=Non def _createAClientUNIXSocket(self, name_of_the_socket): import socket + client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - # client_socket.settimeout(0)header + # client_socket.settimeout(0)header client_socket.connect(name_of_the_socket) return client_socket - + def _remove_socket(self, name_of_the_socket): try: os.remove(name_of_the_socket) except OSError: pass - + def _extra_path_item(self, path_of_the_module): - result = '' + result = "" for x in sys.path: if path_of_the_module.startswith(x): if len(x) > len(result): @@ -1491,53 +1699,55 @@ def check_mpi(self): class SocketMessage(AbstractMessage): def _receive_all(self, nbytes, thesocket): # logger.debug("receiving %d bytes", nbytes) - + result = [] - + while nbytes > 0: chunk = min(nbytes, 10240) data_bytes = thesocket.recv(chunk) - + if len(data_bytes) == 0: raise exceptions.CodeException("lost connection to code") - + result.append(data_bytes) nbytes -= len(data_bytes) # logger.debug("got %d bytes, result length = %d", len(data_bytes), len(result)) - + if len(result) > 0: return type(result[0])().join(result) else: return b"" - + def receive(self, socket): # logger.debug("receiving message") - + header_bytes = self._receive_all(44, socket) - + flags = numpy.frombuffer(header_bytes, dtype="b", count=4, offset=0) - + if flags[0] != self.big_endian: - raise exceptions.CodeException("endianness in message does not match native endianness") - + raise exceptions.CodeException( + "endianness in message does not match native endianness" + ) + if flags[1]: self.error = True else: self.error = False - + header = numpy.copy(numpy.frombuffer(header_bytes, dtype="i", offset=0)) - + # logger.debug("receiving message with flags %s and header %s", flags, header) # id of this call self.call_id = header[1] - + # function ID self.function_id = header[2] - + # number of calls in this message self.call_count = header[3] - + # number of 
X's in TOTAL number_of_ints = header[4] number_of_longs = header[5] @@ -1554,96 +1764,90 @@ def receive(self, socket): self.booleans = self.receive_booleans(socket, number_of_booleans) self.strings = self.receive_strings(socket, number_of_strings) self.encoded_units = self.receive_doubles(socket, number_of_units) - + # logger.debug("message received") - def receive_ints(self, socket, count): if count > 0: nbytes = count * 4 # size of int - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='int32')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="int32")) + return result else: - return [] - + return [] + def receive_longs(self, socket, count): if count > 0: nbytes = count * 8 # size of long - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='int64')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="int64")) + return result else: return [] - - + def receive_floats(self, socket, count): if count > 0: nbytes = count * 4 # size of float - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='f4')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="f4")) + return result else: return [] - - + def receive_doubles(self, socket, count): if count > 0: nbytes = count * 8 # size of double - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='f8')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="f8")) + return result else: return [] - def receive_booleans(self, socket, count): if count > 0: nbytes = count * 1 # size of boolean/byte - + data_bytes = self._receive_all(nbytes, socket) - - result = numpy.copy(numpy.frombuffer(data_bytes, dtype='b')) - + + result = numpy.copy(numpy.frombuffer(data_bytes, dtype="b")) + return result else: return [] - - + def receive_strings(self, socket, count): if count > 0: 
lengths = self.receive_ints(socket, count) - + total = lengths.sum() + len(lengths) - + data_bytes = self._receive_all(total, socket) strings = [] begin = 0 for size in lengths: - strings.append(data_bytes[begin:begin + size].decode('utf-8')) + strings.append(data_bytes[begin : begin + size].decode("utf-8")) begin = begin + size + 1 return numpy.array(strings) else: return [] - + def nonblocking_receive(self, socket): return async_request.ASyncSocketRequest(self, socket) - - + def send(self, socket): flags = numpy.array( [self.big_endian, self.error, len(self.encoded_units) > 0, False], dtype="b" @@ -1666,7 +1870,7 @@ def send(self, socket): ) # logger.debug("sending message with flags %s and header %s", flags, header) - + socket.sendall(flags.tobytes()) socket.sendall(header.tobytes()) @@ -1678,24 +1882,24 @@ def send(self, socket): self.send_booleans(socket, self.booleans) self.send_strings(socket, self.strings) self.send_doubles(socket, self.encoded_units) - + # logger.debug("message send") def send_doubles(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='f8') + data_buffer = numpy.array(array, dtype="f8") socket.sendall(data_buffer.tobytes()) - + def send_ints(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='int32') + data_buffer = numpy.array(array, dtype="int32") socket.sendall(data_buffer.tobytes()) - + def send_floats(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='f4') + data_buffer = numpy.array(array, dtype="f4") socket.sendall(data_buffer.tobytes()) - + def send_strings(self, socket, array): if len(array) > 0: lengths = numpy.array([len(s) for s in array], dtype="int32") @@ -1710,15 +1914,15 @@ def send_strings(self, socket, array): self.send_ints(socket, lengths) socket.sendall(chars) - + def send_booleans(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='b') + data_buffer = numpy.array(array, dtype="b") 
socket.sendall(data_buffer.tobytes()) def send_longs(self, socket, array): if len(array) > 0: - data_buffer = numpy.array(array, dtype='int64') + data_buffer = numpy.array(array, dtype="int64") socket.sendall(data_buffer.tobytes()) @@ -1732,78 +1936,86 @@ def __init__( **options, ): AbstractMessageChannel.__init__(self, **options) - - #logging.getLogger().setLevel(logging.DEBUG) - + + # logging.getLogger().setLevel(logging.DEBUG) + logger.debug("initializing SocketChannel with options %s", options) - + # self.name_of_the_worker = name_of_the_worker + "_sockets" self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + if self.hostname == None: - self.hostname="localhost" + self.hostname = "localhost" - if self.hostname not in ['localhost',socket.gethostname()]: - self.remote=True - self.must_check_if_worker_is_up_to_date=False + if self.hostname not in ["localhost", socket.gethostname()]: + self.remote = True + self.must_check_if_worker_is_up_to_date = False else: - self.remote=False - + self.remote = False + self.id = 0 - + if not legacy_interface_type is None: - self.full_name_of_the_worker = self.get_full_name_of_the_worker(legacy_interface_type) + self.full_name_of_the_worker = self.get_full_name_of_the_worker( + legacy_interface_type + ) else: self.full_name_of_the_worker = self.name_of_the_worker - + logger.debug("full name of worker is %s", self.full_name_of_the_worker) - + self._is_inuse = False self._communicated_splitted_message = False self.socket = None - - self.remote_env=remote_env + + self.remote_env = remote_env @option(sections=("channel",)) def mpiexec(self): """mpiexec with arguments""" if len(config.mpi.mpiexec): return config.mpi.mpiexec - return '' + return "" @option(sections=("channel",)) def mpiexec_number_of_workers_flag(self): """flag to use, so that the number of workers are defined""" - return '-n' + return "-n" @late def debugger_method(self): return self.DEBUGGERS[self.debugger] - + def 
accept_worker_connection(self, server_socket, process): - #wait for the worker to connect. check if the process is still running once in a while + # wait for the worker to connect. check if the process is still running once in a while for i in range(0, 60): - #logger.debug("accepting connection") + # logger.debug("accepting connection") try: server_socket.settimeout(1.0) return server_socket.accept() except socket.timeout: - #update and read returncode + # update and read returncode if process.poll() is not None: - raise exceptions.CodeException('could not connect to worker, worker process terminated') - #logger.error("worker not connecting, waiting...") - - raise exceptions.CodeException('worker still not started after 60 seconds') + raise exceptions.CodeException( + "could not connect to worker, worker process terminated" + ) + # logger.error("worker not connecting, waiting...") + + raise exceptions.CodeException("worker still not started after 60 seconds") - def generate_command_and_arguments(self,server_address,port): + def generate_command_and_arguments(self, server_address, port): arguments = [] - + if not self.debugger_method is None: - command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) + command, arguments = self.debugger_method( + self.full_name_of_the_worker, + self, + interpreter_executable=self.interpreter_executable, + ) else: if ( self.redirect_stdout_file == "none" @@ -1816,9 +2028,15 @@ def generate_command_and_arguments(self,server_address,port): command = self.interpreter_executable arguments = [self.full_name_of_the_worker] else: - command, arguments = self.REDIRECT(self.full_name_of_the_worker, self.redirect_stdout_file, self.redirect_stderr_file, command=self.python_exe_for_redirection, interpreter_executable=self.interpreter_executable) + command, arguments = self.REDIRECT( + self.full_name_of_the_worker, + self.redirect_stdout_file, + self.redirect_stderr_file, + 
command=self.python_exe_for_redirection, + interpreter_executable=self.interpreter_executable, + ) - #start arguments with command + # start arguments with command arguments.insert(0, command) if self.initialize_mpi and len(self.mpiexec) > 0: @@ -1829,69 +2047,85 @@ def generate_command_and_arguments(self,server_address,port): arguments[:0] = mpiexec command = mpiexec[0] - #append with port and hostname where the worker should connect + # append with port and hostname where the worker should connect arguments.append(port) - #hostname of this machine + # hostname of this machine arguments.append(server_address) - - #initialize MPI inside worker executable - arguments.append('true') + + # initialize MPI inside worker executable + arguments.append("true") else: - #append arguments with port and socket where the worker should connect + # append arguments with port and socket where the worker should connect arguments.append(port) - #local machine + # local machine arguments.append(server_address) - - #do not initialize MPI inside worker executable - arguments.append('false') - return command,arguments + # do not initialize MPI inside worker executable + arguments.append("false") + + return command, arguments def remote_env_string(self, hostname): if self.remote_env is None: - if hostname in self.remote_envs.keys(): - return "source "+self.remote_envs[hostname]+"\n" - else: - return "" + if hostname in self.remote_envs.keys(): + return "source " + self.remote_envs[hostname] + "\n" + else: + return "" else: return "source " + self.remote_env + "\n" def generate_remote_command_and_arguments(self, hostname, server_address, port): # get remote config - args=["ssh","-T", hostname] + args = ["ssh", "-T", hostname] + + command = ( + self.remote_env_string(self.hostname) + + "amusifier --get-amuse-config" + + "\n" + ) - command=self.remote_env_string(self.hostname)+ \ - "amusifier --get-amuse-config" +"\n" - - proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") - 
out,err=proc.communicate(command.encode()) + proc = Popen(args, stdout=PIPE, stdin=PIPE, executable="ssh") + out, err = proc.communicate(command.encode()) try: - remote_config=parse_configmk_lines(out.decode().split("\n"),"remote config at "+self.hostname ) + remote_config = parse_configmk_lines( + out.decode().split("\n"), "remote config at " + self.hostname + ) except: - raise Exception(f"failed getting remote config from {self.hostname} - please check remote_env argument ({self.remote_env})") + raise Exception( + f"failed getting remote config from {self.hostname} - please check remote_env argument ({self.remote_env})" + ) # get remote amuse package dir - command=self.remote_env_string(self.hostname)+ \ - "amusifier --get-amuse-package-dir" +"\n" - - proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") - out,err=proc.communicate(command.encode()) - - remote_package_dir=out.decode().strip(" \n\t") - local_package_dir=get_amuse_package_dir() - - mpiexec=remote_config["MPIEXEC"] - initialize_mpi=remote_config["MPI_ENABLED"] == 'yes' - run_command_redirected_file=run_command_redirected.__file__.replace(local_package_dir,remote_package_dir) - interpreter_executable=None if self.interpreter_executable==None else remote_config["PYTHON"] + command = ( + self.remote_env_string(self.hostname) + + "amusifier --get-amuse-package-dir" + + "\n" + ) + + proc = Popen(args, stdout=PIPE, stdin=PIPE, executable="ssh") + out, err = proc.communicate(command.encode()) + + remote_package_dir = out.decode().strip(" \n\t") + local_package_dir = get_amuse_package_dir() + + mpiexec = remote_config["MPIEXEC"] + initialize_mpi = remote_config["MPI_ENABLED"] == "yes" + run_command_redirected_file = run_command_redirected.__file__.replace( + local_package_dir, remote_package_dir + ) + interpreter_executable = ( + None if self.interpreter_executable == None else remote_config["PYTHON"] + ) # dynamic python workers? 
(should be send over) - full_name_of_the_worker=self.full_name_of_the_worker.replace(local_package_dir,remote_package_dir) - python_exe_for_redirection=remote_config["PYTHON"] + full_name_of_the_worker = self.full_name_of_the_worker.replace( + local_package_dir, remote_package_dir + ) + python_exe_for_redirection = remote_config["PYTHON"] if not self.debugger_method is None: raise Exception("remote socket channel debugging not yet supported") - #command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) + # command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) else: if ( self.redirect_stdout_file == "none" @@ -1904,12 +2138,16 @@ def generate_remote_command_and_arguments(self, hostname, server_address, port): command = interpreter_executable arguments = [full_name_of_the_worker] else: - command, arguments = self.REDIRECT(full_name_of_the_worker, self.redirect_stdout_file, - self.redirect_stderr_file, command=python_exe_for_redirection, - interpreter_executable=interpreter_executable, - run_command_redirected_file=run_command_redirected_file) + command, arguments = self.REDIRECT( + full_name_of_the_worker, + self.redirect_stdout_file, + self.redirect_stderr_file, + command=python_exe_for_redirection, + interpreter_executable=interpreter_executable, + run_command_redirected_file=run_command_redirected_file, + ) - #start arguments with command + # start arguments with command arguments.insert(0, command) if initialize_mpi and len(mpiexec) > 0: @@ -1920,74 +2158,107 @@ def generate_remote_command_and_arguments(self, hostname, server_address, port): arguments[:0] = mpiexec command = mpiexec[0] - #append with port and hostname where the worker should connect + # append with port and hostname where the worker should connect arguments.append(port) - #hostname of this machine + # hostname of this machine 
arguments.append(server_address) - - #initialize MPI inside worker executable - arguments.append('true') + + # initialize MPI inside worker executable + arguments.append("true") else: - #append arguments with port and socket where the worker should connect + # append arguments with port and socket where the worker should connect arguments.append(port) - #local machine + # local machine arguments.append(server_address) - - #do not initialize MPI inside worker executable - arguments.append('false') - return command,arguments + # do not initialize MPI inside worker executable + arguments.append("false") + + return command, arguments def start(self): server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - server_address=self.get_host_ip(self.hostname) - - server_socket.bind((server_address , 0)) + + server_address = self.get_host_ip(self.hostname) + + server_socket.bind((server_address, 0)) server_socket.settimeout(1.0) server_socket.listen(1) - - logger.debug("starting socket worker process, listening for worker connection on %s", server_socket.getsockname()) - #this option set by CodeInterface + logger.debug( + "starting socket worker process, listening for worker connection on %s", + server_socket.getsockname(), + ) + + # this option set by CodeInterface logger.debug("mpi_enabled: %s", str(self.initialize_mpi)) - - # set arguments to name of the worker, and port number we listen on + + # set arguments to name of the worker, and port number we listen on self.stdout = None self.stderr = None - + if self.remote: - command,arguments=self.generate_remote_command_and_arguments(self.hostname,server_address,str(server_socket.getsockname()[1])) + command, arguments = self.generate_remote_command_and_arguments( + self.hostname, server_address, str(server_socket.getsockname()[1]) + ) else: - command,arguments=self.generate_command_and_arguments(server_address,str(server_socket.getsockname()[1])) - + command, arguments = self.generate_command_and_arguments( + 
server_address, str(server_socket.getsockname()[1]) + ) + if self.remote: - logger.debug("starting remote process on %s with command `%s`, arguments `%s` and environment '%s'", self.hostname, command, arguments, os.environ) - ssh_command=self.remote_env_string(self.hostname)+" ".join(arguments) - arguments=["ssh","-T", self.hostname] - command="ssh" - self.process = Popen(arguments, executable=command, stdin=PIPE, stdout=None, stderr=None, close_fds=self.close_fds) - self.process.stdin.write(ssh_command.encode()) - self.process.stdin.close() + logger.debug( + "starting remote process on %s with command `%s`, arguments `%s` and environment '%s'", + self.hostname, + command, + arguments, + os.environ, + ) + ssh_command = self.remote_env_string(self.hostname) + " ".join(arguments) + arguments = ["ssh", "-T", self.hostname] + command = "ssh" + self.process = Popen( + arguments, + executable=command, + stdin=PIPE, + stdout=None, + stderr=None, + close_fds=self.close_fds, + ) + self.process.stdin.write(ssh_command.encode()) + self.process.stdin.close() else: - logger.debug("starting process with command `%s`, arguments `%s` and environment '%s'", command, arguments, os.environ) - # ~ print(arguments) - self.process = Popen(arguments, executable=command, stdin=PIPE, stdout=None, stderr=None, close_fds=self.close_fds) + logger.debug( + "starting process with command `%s`, arguments `%s` and environment '%s'", + command, + arguments, + os.environ, + ) + # ~ print(arguments) + self.process = Popen( + arguments, + executable=command, + stdin=PIPE, + stdout=None, + stderr=None, + close_fds=self.close_fds, + ) logger.debug("waiting for connection from worker") - self.socket, address = self.accept_worker_connection(server_socket, self.process) - + self.socket, address = self.accept_worker_connection( + server_socket, self.process + ) + self.socket.setblocking(1) - + self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - + server_socket.close() - + # logger.debug("got 
connection from %s", address) - + # logger.info("worker %s initialized", self.name_of_the_worker) - @option(type="boolean", sections=("sockets_channel",)) def close_fds(self): @@ -1996,52 +2267,52 @@ def close_fds(self): @option(type="dict", sections=("sockets_channel",)) def remote_envs(self): - """ dict of remote machine - enviroment (source ..) pairs """ + """dict of remote machine - enviroment (source ..) pairs""" return dict() @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): """Name of the debugger to use when starting the code""" return "none" - + @option(sections=("channel",)) def hostname(self): return None - + def stop(self): - if (self.socket == None): + if self.socket == None: return - + logger.debug("stopping socket worker %s", self.name_of_the_worker) self.socket.close() - + self.socket = None if not self.process.stdin is None: self.process.stdin.close() - + # should lookinto using poll with a timeout or some other mechanism # when debugger method is on, no killing count = 0 - while(count < 5): + while count < 5: returncode = self.process.poll() if not returncode is None: break time.sleep(0.2) count += 1 - + if not self.stdout is None: self.stdout.close() - + if not self.stderr is None: self.stderr.close() def is_active(self): return self.socket is not None - + def is_inuse(self): return self._is_inuse - + def determine_length_from_datax(self, dtype_to_arguments): def get_length(type_and_values): argument_type, argument_values = type_and_values @@ -2054,98 +2325,134 @@ def get_length(type_and_values): except: result = max(result, 1) return result - - - + lengths = [get_length(x) for x in dtype_to_arguments.items()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) def send_message( self, call_id, function_id, dtype_to_arguments={}, encoded_units=() ): call_count = self.determine_length_from_data(dtype_to_arguments) - + # logger.info("sending message for call id %d, function %d, length %d", 
id, tag, length) - + if self.is_inuse(): - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) if self.socket is None: - raise exceptions.CodeException("You've tried to send a message to a code that is not running") - - + raise exceptions.CodeException( + "You've tried to send a message to a code that is not running" + ) + if call_count > self.max_message_length: - self.split_message(call_id, function_id, call_count, dtype_to_arguments, encoded_units) + self.split_message( + call_id, function_id, call_count, dtype_to_arguments, encoded_units + ) else: - message = SocketMessage(call_id, function_id, call_count, dtype_to_arguments, encoded_units = encoded_units) + message = SocketMessage( + call_id, + function_id, + call_count, + dtype_to_arguments, + encoded_units=encoded_units, + ) message.send(self.socket) self._is_inuse = True def recv_message(self, call_id, function_id, handle_as_array, has_units=False): self._is_inuse = False - + if self._communicated_splitted_message: x = self._merged_results_splitted_message self._communicated_splitted_message = False del self._merged_results_splitted_message return x - + message = SocketMessage() - + message.receive(self.socket) if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] if len(message.strings) > 0 else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." + self.stop() + error_message += " - code probably died, sorry." 
raise exceptions.CodeException("Error in code: " + error_message) if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) - def nonblocking_recv_message(self, call_id, function_id, handle_as_array, has_units=False): + def nonblocking_recv_message( + self, call_id, function_id, handle_as_array, has_units=False + ): request = SocketMessage().nonblocking_receive(self.socket) - + def handle_result(function): self._is_inuse = False - + message = function() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] + if len(message.strings) > 0 + else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." - raise exceptions.CodeException("Error in (asynchronous) communication with worker: " + error_message) - + self.stop() + error_message += " - code probably died, sorry." 
+ raise exceptions.CodeException( + "Error in (asynchronous) communication with worker: " + + error_message + ) + if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) - + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) + if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) request.add_result_handler(handle_result) - + return request @option(type="int", sections=("channel",)) @@ -2153,29 +2460,29 @@ def max_message_length(self): """ For calls to functions that can handle arrays, MPI messages may get too long for large N. The MPI channel will split long messages into blocks of size max_message_length. 
- """ + """ return 1000000 - def sanitize_host(self,hostname): + def sanitize_host(self, hostname): if "@" in hostname: - return hostname.split("@")[1] + return hostname.split("@")[1] return hostname - + def get_host_ip(self, client): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect((self.sanitize_host(client), 80)) - ip=s.getsockname()[0] + ip = s.getsockname()[0] s.close() return ip - def makedirs(self,directory): - if self.remote: - args=["ssh","-T", self.hostname] - command=f"mkdir -p {directory}\n" - proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") - out,err=proc.communicate(command.encode()) - else: - os.makedirs(directory) + def makedirs(self, directory): + if self.remote: + args = ["ssh", "-T", self.hostname] + command = f"mkdir -p {directory}\n" + proc = Popen(args, stdout=PIPE, stdin=PIPE, executable="ssh") + out, err = proc.communicate(command.encode()) + else: + os.makedirs(directory) class OutputHandler(threading.Thread): @@ -2184,207 +2491,256 @@ def __init__(self, stream, port): self.stream = stream logger.debug("output handler connecting to daemon at %d", port) - + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - address = ('localhost', port) - + + address = ("localhost", port) + try: self.socket.connect(address) except: - raise exceptions.CodeException("Could not connect to Distributed Daemon at " + str(address)) - + raise exceptions.CodeException( + "Could not connect to Distributed Daemon at " + str(address) + ) + self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - - self.socket.sendall('TYPE_OUTPUT'.encode('utf-8')) + + self.socket.sendall("TYPE_OUTPUT".encode("utf-8")) # fetch ID of this connection - + result = SocketMessage() result.receive(self.socket) - + self.id = result.strings[0] - + logger.debug("output handler successfully connected to daemon at %d", port) self.daemon = True self.start() - + def run(self): while True: # logger.debug("receiving data for output") data = 
self.socket.recv(1024) - + if len(data) == 0: # logger.debug("end of output", len(data)) return - + # logger.debug("got %d bytes", len(data)) - + self.stream.write(data) class DistributedChannel(AbstractMessageChannel): default_distributed_instance = None - + @staticmethod def getStdoutID(instance): if not hasattr(instance, "_stdoutHandler") or instance._stdoutHandler is None: instance._stdoutHandler = OutputHandler(sys.stdout, instance.port) - + return instance._stdoutHandler.id - + @staticmethod def getStderrID(instance): if not hasattr(instance, "_stderrHandler") or instance._stderrHandler is None: instance._stderrHandler = OutputHandler(sys.stderr, instance.port) - + return instance._stderrHandler.id - - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, - distributed_instance=None, dynamic_python_code=False, **options): + + def __init__( + self, + name_of_the_worker, + legacy_interface_type=None, + interpreter_executable=None, + distributed_instance=None, + dynamic_python_code=False, + **options, + ): AbstractMessageChannel.__init__(self, **options) - + self._is_inuse = False self._communicated_splitted_message = False - + if distributed_instance is None: if self.default_distributed_instance is None: - raise Exception("No default distributed instance present, and none explicitly passed to code") + raise Exception( + "No default distributed instance present, and none explicitly passed to code" + ) self.distributed_instance = self.default_distributed_instance else: self.distributed_instance = distributed_instance - - #logger.setLevel(logging.DEBUG) - + + # logger.setLevel(logging.DEBUG) + logger.info("initializing DistributedChannel with options %s", options) - - self.socket=None - + + self.socket = None + self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - + self.dynamic_python_code = dynamic_python_code - + if self.number_of_workers == 0: self.number_of_workers = 1 - + 
if self.label == None: self.label = "" - - logger.debug("number of workers is %d, number of threads is %s, label is %s", self.number_of_workers, self.number_of_threads, self.label) - - self.daemon_host = 'localhost' # Distributed process always running on the local machine - self.daemon_port = self.distributed_instance.port # Port number for the Distributed process + + logger.debug( + "number of workers is %d, number of threads is %s, label is %s", + self.number_of_workers, + self.number_of_threads, + self.label, + ) + + self.daemon_host = ( + "localhost" # Distributed process always running on the local machine + ) + self.daemon_port = ( + self.distributed_instance.port + ) # Port number for the Distributed process logger.debug("port is %d", self.daemon_port) - + self.id = 0 - + if not legacy_interface_type is None: # worker specified by type. Figure out where this file is # mostly (only?) used by dynamic python codes - directory_of_this_module = os.path.dirname(inspect.getfile(legacy_interface_type)) - worker_path = os.path.join(directory_of_this_module, self.name_of_the_worker) - self.full_name_of_the_worker = os.path.normpath(os.path.abspath(worker_path)) - + directory_of_this_module = os.path.dirname( + inspect.getfile(legacy_interface_type) + ) + worker_path = os.path.join( + directory_of_this_module, self.name_of_the_worker + ) + self.full_name_of_the_worker = os.path.normpath( + os.path.abspath(worker_path) + ) + self.name_of_the_worker = os.path.basename(self.full_name_of_the_worker) - + else: # worker specified by executable (usually already absolute) - self.full_name_of_the_worker = os.path.normpath(os.path.abspath(self.name_of_the_worker)) - + self.full_name_of_the_worker = os.path.normpath( + os.path.abspath(self.name_of_the_worker) + ) + global_options = GlobalOptions() - - self.executable = os.path.relpath(self.full_name_of_the_worker, global_options.amuse_rootdirectory) - + + self.executable = os.path.relpath( + self.full_name_of_the_worker, 
global_options.amuse_rootdirectory + ) + self.worker_dir = os.path.dirname(self.full_name_of_the_worker) - + logger.debug("executable is %s", self.executable) logger.debug("full name of the worker is %s", self.full_name_of_the_worker) - + logger.debug("worker dir is %s", self.worker_dir) - + self._is_inuse = False def check_if_worker_is_up_to_date(self, object): -# if self.hostname != 'localhost': -# return -# -# logger.debug("hostname = %s, checking for worker", self.hostname) -# -# AbstractMessageChannel.check_if_worker_is_up_to_date(self, object) - + # if self.hostname != 'localhost': + # return + # + # logger.debug("hostname = %s, checking for worker", self.hostname) + # + # AbstractMessageChannel.check_if_worker_is_up_to_date(self, object) + pass - + def start(self): logger.debug("connecting to daemon") - + # if redirect = none, set output file to console stdout stream ID, otherwise make absolute - if (self.redirect_stdout_file == 'none'): + if self.redirect_stdout_file == "none": self.redirect_stdout_file = self.getStdoutID(self.distributed_instance) else: self.redirect_stdout_file = os.path.abspath(self.redirect_stdout_file) # if redirect = none, set error file to console stderr stream ID, otherwise make absolute - if (self.redirect_stderr_file == 'none'): + if self.redirect_stderr_file == "none": self.redirect_stderr_file = self.getStderrID(self.distributed_instance) else: self.redirect_stderr_file = os.path.abspath(self.redirect_stderr_file) - + logger.debug("output send to = " + self.redirect_stdout_file) - + logger.debug("error send to = " + self.redirect_stderr_file) - + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: self.socket.connect((self.daemon_host, self.daemon_port)) except: self.socket = None - raise exceptions.CodeException("Could not connect to Ibis Daemon at " + str(self.daemon_port)) - + raise exceptions.CodeException( + "Could not connect to Ibis Daemon at " + str(self.daemon_port) + ) + self.socket.setblocking(1) - + 
self.socket.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - - self.socket.sendall('TYPE_WORKER'.encode('utf-8')) - - arguments = {'string': [self.executable, self.redirect_stdout_file, self.redirect_stderr_file, self.label, self.worker_dir], 'int32': [self.number_of_workers, self.number_of_threads], 'bool': [ self.dynamic_python_code]} - - message = SocketMessage(call_id=1, function_id=10101010, call_count=1, dtype_to_arguments=arguments) + + self.socket.sendall("TYPE_WORKER".encode("utf-8")) + + arguments = { + "string": [ + self.executable, + self.redirect_stdout_file, + self.redirect_stderr_file, + self.label, + self.worker_dir, + ], + "int32": [self.number_of_workers, self.number_of_threads], + "bool": [self.dynamic_python_code], + } + + message = SocketMessage( + call_id=1, function_id=10101010, call_count=1, dtype_to_arguments=arguments + ) message.send(self.socket) - + logger.info("waiting for worker %s to be initialized", self.name_of_the_worker) result = SocketMessage() result.receive(self.socket) - + if result.error: logger.error("Could not start worker: %s", result.strings[0]) self.stop() - raise exceptions.CodeException("Could not start worker for " + self.name_of_the_worker + ": " + result.strings[0]) - + raise exceptions.CodeException( + "Could not start worker for " + + self.name_of_the_worker + + ": " + + result.strings[0] + ) + self.remote_amuse_dir = result.strings[0] - + logger.info("worker %s initialized", self.name_of_the_worker) logger.info("worker remote amuse dir = %s", self.remote_amuse_dir) - + @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): """Name of the debugger to use when starting the code""" return "none" - + def get_amuse_root_directory(self): return self.remote_amuse_dir - + @option(type="int", sections=("channel",)) def number_of_threads(self): return 0 - + @option(type="string", sections=("channel",)) def label(self): return None - + def stop(self): if self.socket is not 
None: logger.info("stopping worker %s", self.name_of_the_worker) @@ -2392,11 +2748,11 @@ def stop(self): self.socket = None def is_active(self): - return self.socket is not None - + return self.socket is not None + def is_inuse(self): return self._is_inuse - + def determine_length_from_datax(self, dtype_to_arguments): def get_length(x): if x: @@ -2405,104 +2761,132 @@ def get_length(x): return len(x[0]) except: return 1 - - - + lengths = [get_length(x) for x in dtype_to_arguments.values()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) def send_message( self, call_id, function_id, dtype_to_arguments={}, encoded_units=None ): call_count = self.determine_length_from_data(dtype_to_arguments) - - logger.debug("sending message for call id %d, function %d, length %d", call_id, function_id, call_count) - + + logger.debug( + "sending message for call id %d, function %d, length %d", + call_id, + function_id, + call_count, + ) + if self.is_inuse(): - raise exceptions.CodeException("You've tried to send a message to a code that is already handling a message, this is not correct") + raise exceptions.CodeException( + "You've tried to send a message to a code that is already handling a message, this is not correct" + ) if self.socket is None: - raise exceptions.CodeException("You've tried to send a message to a code that is not running") - + raise exceptions.CodeException( + "You've tried to send a message to a code that is not running" + ) + if call_count > self.max_message_length: - self.split_message(call_id, function_id, call_count, dtype_to_arguments, encoded_units) + self.split_message( + call_id, function_id, call_count, dtype_to_arguments, encoded_units + ) else: - message = SocketMessage(call_id, function_id, call_count, dtype_to_arguments, False, False) + message = SocketMessage( + call_id, function_id, call_count, dtype_to_arguments, False, False + ) message.send(self.socket) self._is_inuse = True - def recv_message(self, call_id, function_id, 
handle_as_array, has_units=False): self._is_inuse = False - + if self._communicated_splitted_message: x = self._merged_results_splitted_message self._communicated_splitted_message = False del self._merged_results_splitted_message return x - + message = SocketMessage() - + message.receive(self.socket) if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] if len(message.strings) > 0 else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - # self.stop() - error_message+=" - code probably died, sorry." + # self.stop() + error_message += " - code probably died, sorry." raise exceptions.CodeException("Error in worker: " + error_message) if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) - - - def nonblocking_recv_message(self, call_id, function_id, handle_as_array, has_units=False): + def nonblocking_recv_message( + self, call_id, function_id, handle_as_array, has_units=False + ): # raise exceptions.CodeException("Nonblocking receive not supported by DistributedChannel") request = SocketMessage().nonblocking_receive(self.socket) - + def handle_result(function): self._is_inuse = False - + message = function() if message.error: - error_message=message.strings[0] if len(message.strings)>0 else "no error message" + error_message = ( + message.strings[0] + if len(message.strings) > 0 + else "no error message" + ) if message.call_id != call_id or message.function_id != function_id: - self.stop() - error_message+=" - code probably died, sorry." - raise exceptions.CodeException("Error in (asynchronous) communication with worker: " + error_message) - + self.stop() + error_message += " - code probably died, sorry." 
+ raise exceptions.CodeException( + "Error in (asynchronous) communication with worker: " + + error_message + ) + if message.call_id != call_id: self.stop() - raise exceptions.CodeException('Received reply for call id {0} but expected {1}'.format(message.call_id, call_id)) - + raise exceptions.CodeException( + "Received reply for call id {0} but expected {1}".format( + message.call_id, call_id + ) + ) + if message.function_id != function_id: self.stop() - raise exceptions.CodeException('Received reply for function id {0} but expected {1}'.format(message.function_id, function_id)) - + raise exceptions.CodeException( + "Received reply for function id {0} but expected {1}".format( + message.function_id, function_id + ) + ) + if has_units: return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) request.add_result_handler(handle_result) - + return request - + @option(type="int", sections=("channel",)) def max_message_length(self): """ For calls to functions that can handle arrays, MPI messages may get too long for large N. The MPI channel will split long messages into blocks of size max_message_length. - """ + """ return 1000000 + class LocalChannel(AbstractMessageChannel): def __init__( self, @@ -2518,41 +2902,40 @@ def __init__( if not legacy_interface_type is None: self.so_module = legacy_interface_type.__so_module__ - self.package, _ = legacy_interface_type.__module__.rsplit('.',1) + self.package, _ = legacy_interface_type.__module__.rsplit(".", 1) else: - raise Exception("Need to give the legacy interface type for the local channel") - + raise Exception( + "Need to give the legacy interface type for the local channel" + ) + self.legacy_interface_type = legacy_interface_type self._is_inuse = False self.module = None - - - def check_if_worker_is_up_to_date(self, object): pass - + def start(self): from . import import_module from . import python_code - + module = import_module.import_unique(self.package + "." 
+ self.so_module) print(module, self.package + "." + self.so_module) module.set_comm_world(MPI.COMM_SELF) - self.local_implementation = python_code.CythonImplementation(module, self.legacy_interface_type) + self.local_implementation = python_code.CythonImplementation( + module, self.legacy_interface_type + ) self.module = module - - def stop(self): from . import import_module + import_module.cleanup_module(self.module) self.module = None - def is_active(self): return not self.module is None - + def is_inuse(self): return self._is_inuse @@ -2560,29 +2943,32 @@ def send_message( self, call_id, function_id, dtype_to_arguments={}, encoded_units=None ): call_count = self.determine_length_from_data(dtype_to_arguments) - - self.message = LocalMessage(call_id, function_id, call_count, dtype_to_arguments, encoded_units = encoded_units) - self.is_inuse = True - + self.message = LocalMessage( + call_id, + function_id, + call_count, + dtype_to_arguments, + encoded_units=encoded_units, + ) + self.is_inuse = True def recv_message(self, call_id, function_id, handle_as_array, has_units=False): output_message = LocalMessage(call_id, function_id, self.message.call_count) self.local_implementation.handle_message(self.message, output_message) - + if has_units: - return output_message.to_result(handle_as_array),output_message.encoded_units + return ( + output_message.to_result(handle_as_array), + output_message.encoded_units, + ) else: return output_message.to_result(handle_as_array) - - - def nonblocking_recv_message(self, call_id, function_id, handle_as_array): pass - def determine_length_from_datax(self, dtype_to_arguments): def get_length(x): if x: @@ -2592,19 +2978,16 @@ def get_length(x): except: return 1 return 1 - - - + lengths = [get_length(x) for x in dtype_to_arguments.values()] if len(lengths) == 0: return 1 - + return max(1, max(lengths)) - - def is_polling_supported(self): return False + class LocalMessage(AbstractMessage): pass From 
6957c42f04354f3fa60253f7a00a1c7e523b3081 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Mon, 14 Oct 2024 10:54:26 +0200 Subject: [PATCH 11/12] revised with Black and manual updates --- src/amuse/rfi/nospawn.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/amuse/rfi/nospawn.py b/src/amuse/rfi/nospawn.py index da3733ec15..9c6eeb950a 100644 --- a/src/amuse/rfi/nospawn.py +++ b/src/amuse/rfi/nospawn.py @@ -1,10 +1,9 @@ -from amuse.rfi import core -from amuse.rfi.python_code import CythonImplementation +import importlib +from collections import namedtuple from mpi4py import MPI +from amuse.rfi import core from amuse.rfi import channel -from collections import namedtuple -import sys -import importlib +from amuse.rfi.python_code import CythonImplementation Code = namedtuple("Code", ["cls", "number_of_workers", "args", "kwargs"]) PythonCode = namedtuple( @@ -64,9 +63,9 @@ def start_all(codes): if world.size < number_of_workers_needed: if rank == 0: raise Exception( - "cannot start all codes, the world size ({0}) is smaller than the number of requested codes ({1}) (which is always 1 + the sum of the all the number_of_worker fields)".format( - world.size, number_of_workers_needed - ) + f"Cannot start all codes, the world size ({world.size}) is smaller " + f"than the number of requested codes ({number_of_workers_needed}) " + f"(which is always 1 + the sum of the all the number_of_worker fields)" ) else: return None @@ -136,12 +135,12 @@ def start_empty(): world = MPI.COMM_WORLD rank = world.rank - color = 0 if world.rank == 0 else 1 - key = 0 if world.rank == 0 else world.rank - 1 + color = 0 if rank == 0 else 1 + key = 0 if rank == 0 else rank - 1 newcomm = world.Split(color, key) localdup = world.Dup() - if world.rank == 0: + if rank == 0: result = [] remote_leader = 1 tag = 1 @@ -167,17 +166,16 @@ def start_empty(): instance.must_disconnect = False world.Barrier() instance.start() - print("STOP...", 
world.rank) + print(f"STOP... {world.rank}") return None def get_code(rank, codes): if rank == 0: return None - else: - index = 1 - for color, x in enumerate(codes): - if rank >= index and rank < index + x.number_of_workers: - return x - index += x.number_of_workers + index = 1 + for color, x in enumerate(codes): + if rank >= index and rank < index + x.number_of_workers: + return x + index += x.number_of_workers return None From cc5513747580dd3e41427ec85b83890841a39859 Mon Sep 17 00:00:00 2001 From: Steven Rieder Date: Mon, 14 Oct 2024 10:55:06 +0200 Subject: [PATCH 12/12] not is None -> is not None --- src/amuse/rfi/run_command_redirected.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/amuse/rfi/run_command_redirected.py b/src/amuse/rfi/run_command_redirected.py index 3f2f0e7257..2d1623846a 100644 --- a/src/amuse/rfi/run_command_redirected.py +++ b/src/amuse/rfi/run_command_redirected.py @@ -41,10 +41,10 @@ def translate_filename_for_os(filename): stdin.close() - if not stdout is None: + if stdout is not None: stdout.close() - if not stderr is None: + if stderr is not None: stderr.close() sys.exit(returncode)