From c1d30358afff15203a1ec327157752a578d713ee Mon Sep 17 00:00:00 2001 From: Christopher Rink Date: Sun, 31 Dec 2023 14:07:04 -0500 Subject: [PATCH 1/2] Hide the arity map in generated Python --- src/basilisp/lang/compiler/generator.py | 78 ++++++++++++++----------- src/basilisp/lang/runtime.py | 9 ++- 2 files changed, 53 insertions(+), 34 deletions(-) diff --git a/src/basilisp/lang/compiler/generator.py b/src/basilisp/lang/compiler/generator.py index df0a1451..500598e4 100644 --- a/src/basilisp/lang/compiler/generator.py +++ b/src/basilisp/lang/compiler/generator.py @@ -1648,35 +1648,47 @@ def __fn_args_to_py_ast( def __fn_decorator( arities: Iterable[int], has_rest_arg: bool = False, + arity_map: Optional[ast.Dict] = None, + bind_self_arg: Optional[bool] = None, ) -> ast.Call: return ast.Call( func=_BASILISP_FN_FN_NAME, args=[], - keywords=[ - ast.keyword( - arg="arities", - value=ast.Tuple( - elts=list( - chain( - map(ast.Constant, arities), - [ - ast.Call( - func=_NEW_KW_FN_NAME, - args=[ - ast.Constant(hash(_REST_KW)), - ast.Constant("rest"), - ], - keywords=[], + keywords=list( + chain( + [ + ast.keyword( + arg="arities", + value=ast.Tuple( + elts=list( + chain( + map(ast.Constant, arities), + [ + ast.Call( + func=_NEW_KW_FN_NAME, + args=[ + ast.Constant(hash(_REST_KW)), + ast.Constant("rest"), + ], + keywords=[], + ) + ] + if has_rest_arg + else [], ) - ] - if has_rest_arg - else [], - ) - ), - ctx=ast.Load(), - ), + ), + ctx=ast.Load(), + ), + ) + ], + [ast.keyword(arg="arity_map", value=arity_map)] + if arity_map is not None + else [], + [ast.keyword(arg="bind_self", value=ast.Constant(bind_self_arg))] + if bind_self_arg + else [], ) - ], + ), ) @@ -1827,8 +1839,6 @@ def fn(*args): return default(*args) raise RuntimeError """ - dispatch_map_name = f"{name}_dispatch_map" - dispatch_keys, dispatch_vals = [], [] for k, v in arity_map.items(): dispatch_keys.append(ast.Constant(k)) @@ -1852,7 +1862,11 @@ def fn(*args): targets=[ast.Name(id=arity_name, ctx=ast.Store())], value=ast.Call( func=ast.Attribute( - value=ast.Name(id=dispatch_map_name, ctx=ast.Load()), + value=ast.Attribute( + value=ast.Name(id="self", ctx=ast.Load()), + attr="arity_map", + ctx=ast.Load(), + ), attr="get", ctx=ast.Load(), ), @@ -1948,12 +1962,6 @@ def fn(*args): return GeneratedPyAST( node=ast.Name(id=name, ctx=ast.Load()), dependencies=chain( - [ - ast.Assign( - targets=[ast.Name(id=dispatch_map_name, ctx=ast.Store())], - value=ast.Dict(keys=dispatch_keys, values=dispatch_vals), - ) - ], meta_deps, ret_ann_deps, [ @@ -1961,7 +1969,7 @@ def fn(*args): name=name, args=ast.arguments( posonlyargs=[], - args=[], + args=[ast.arg(arg="self", annotation=None)], kwarg=None, vararg=ast.arg(arg=_MULTI_ARITY_ARG_NAME, annotation=None), kwonlyargs=[], @@ -1976,6 +1984,10 @@ def fn(*args): __fn_decorator( arity_map.keys(), has_rest_arg=default_name is not None, + arity_map=ast.Dict( + keys=dispatch_keys, values=dispatch_vals + ), + bind_self_arg=True, ) ], ) diff --git a/src/basilisp/lang/runtime.py b/src/basilisp/lang/runtime.py index 6440f868..2f9b32e9 100644 --- a/src/basilisp/lang/runtime.py +++ b/src/basilisp/lang/runtime.py @@ -1776,7 +1776,11 @@ def wrapped_f(*args, **kwargs): return wrapped_f -def _basilisp_fn(arities: Tuple[Union[int, kw.Keyword]]): +def _basilisp_fn( + arities: Tuple[Union[int, kw.Keyword]], + arity_map: Optional[Mapping[Union[int, kw.Keyword], Callable]] = None, + bind_self: bool = False, +): """Create a Basilisp function, setting meta and supplying a with_meta method implementation.""" @@ -1784,7 +1788,10 @@ def wrap_fn(f): assert not hasattr(f, "meta") f._basilisp_fn = True f.arities = lset.set(arities) + f.arity_map = lmap.map(arity_map) if arity_map is not None else None f.meta = None + if bind_self: + f = partial(f, f) f.with_meta = partial(_fn_with_meta, f) return f From d9ab9684a9208d5706f86038e5524ae9c9a95eed Mon Sep 17 00:00:00 2001 From: Christopher Rink Date: Wed, 3 Jan 2024 08:51:05 -0500 Subject: [PATCH 2/2] Try this --- src/basilisp/lang/compiler/generator.py | 216 ++++++++++-------------- src/basilisp/lang/runtime.py | 57 ++++++- 2 files changed, 146 insertions(+), 127 deletions(-) diff --git a/src/basilisp/lang/compiler/generator.py b/src/basilisp/lang/compiler/generator.py index 500598e4..0223e5ef 100644 --- a/src/basilisp/lang/compiler/generator.py +++ b/src/basilisp/lang/compiler/generator.py @@ -697,6 +697,9 @@ def _var_ns_as_python_sym(name: str) -> str: _COLLECT_ARGS_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._collect_args") _COERCE_SEQ_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}.to_seq") _BASILISP_FN_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._basilisp_fn") +_BASILISP_MULTI_ARITY_FN_FN_NAME = _load_attr( + f"{_RUNTIME_ALIAS}._basilisp_multi_arity_fn" +) _FN_WITH_ATTRS_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._with_attrs") _BASILISP_TYPE_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._basilisp_type") _BASILISP_WITH_META_INTERFACE_NAME = _load_attr(f"{_INTERFACES_ALIAS}.IWithMeta") @@ -1844,103 +1847,6 @@ def fn(*args): dispatch_keys.append(ast.Constant(k)) dispatch_vals.append(ast.Name(id=v, ctx=ast.Load())) - # Async functions should return await, otherwise just return - handle_return = __handle_async_return if is_async else __handle_return - - nargs_name = genname("nargs") - arity_name = genname("arity") - body = [ - ast.Assign( - targets=[ast.Name(id=nargs_name, ctx=ast.Store())], - value=ast.Call( - func=ast.Name(id="len", ctx=ast.Load()), - args=[ast.Name(id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load())], - keywords=[], - ), - ), - ast.Assign( - targets=[ast.Name(id=arity_name, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Attribute( - value=ast.Name(id="self", ctx=ast.Load()), - attr="arity_map", - ctx=ast.Load(), - ), - attr="get", - ctx=ast.Load(), - ), - args=[ast.Name(id=nargs_name, ctx=ast.Load())], - keywords=[], - ), - ), - ast.If( - test=ast.Compare( - left=ast.Constant(None), - ops=[ast.IsNot()], - comparators=[ast.Name(id=arity_name, ctx=ast.Load())], - ), - body=[ - handle_return( - ast.Call( - func=ast.Name(id=arity_name, ctx=ast.Load()), - args=[ - ast.Starred( - value=ast.Name( - id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load() - ), - ctx=ast.Load(), - ) - ], - keywords=[], - ) - ) - ], - orelse=[] - if default_name is None - else [ - ast.If( - test=ast.Compare( - left=ast.Name(id=nargs_name, ctx=ast.Load()), - ops=[ast.GtE()], - comparators=[ast.Constant(max_fixed_arity)], - ), - body=[ - handle_return( - ast.Call( - func=ast.Name(id=default_name, ctx=ast.Load()), - args=[ - ast.Starred( - value=ast.Name( - id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load() - ), - ctx=ast.Load(), - ) - ], - keywords=[], - ) - ) - ], - orelse=[], - ) - ], - ), - ast.Raise( - exc=ast.Call( - func=_load_attr("basilisp.lang.runtime.RuntimeException"), - args=[ - ast.Constant(f"Wrong number of args passed to function: {name}"), - ast.Name(id=nargs_name, ctx=ast.Load()), - ], - keywords=[], - ), - cause=None, - ), - ] - - py_fn_node = ast.AsyncFunctionDef if is_async else ast.FunctionDef - meta_deps, meta_decorators = __fn_meta(ctx, meta_node) - ret_ann_ast: Optional[ast.AST] = None ret_ann_deps: List[ast.AST] = [] if all(tag is not None for tag in return_tags): @@ -1959,40 +1865,103 @@ def fn(*args): else None ) + meta_py_ast: Optional[ast.AST] = None + meta_deps: List[ast.AST] = [] + if meta_node is not None: + meta_ast = gen_py_ast(ctx, meta_node) + meta_py_ast = meta_ast.node + meta_deps.extend(meta_ast.dependencies) + return GeneratedPyAST( node=ast.Name(id=name, ctx=ast.Load()), dependencies=chain( meta_deps, ret_ann_deps, [ - py_fn_node( - name=name, - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg="self", annotation=None)], - kwarg=None, - vararg=ast.arg(arg=_MULTI_ARITY_ARG_NAME, annotation=None), - kwonlyargs=[], - defaults=[], - kw_defaults=[], - ), - body=body, - decorator_list=list( - chain( - meta_decorators, - [ - __fn_decorator( - arity_map.keys(), - has_rest_arg=default_name is not None, - arity_map=ast.Dict( - keys=dispatch_keys, values=dispatch_vals - ), - bind_self_arg=True, + _tagged_assign( + target=ast.Name(id=name, ctx=ast.Store()), + value=ast.Call( + func=ast.Call( + func=_BASILISP_MULTI_ARITY_FN_FN_NAME, + args=[], + keywords=list( + chain( + [ + ast.keyword( + arg="arities", + value=ast.Tuple( + elts=list( + chain( + map( + ast.Constant, + arity_map.keys(), + ), + [ + ast.Call( + func=_NEW_KW_FN_NAME, + args=[ + ast.Constant( + hash(_REST_KW) + ), + ast.Constant( + "rest" + ), + ], + keywords=[], + ) + ] + if default_name is not None + else [], + ) + ), + ctx=ast.Load(), + ), + ), + ast.keyword( + arg="arity_map", + value=ast.Dict( + keys=dispatch_keys, values=dispatch_vals + ), + ), + ], + [ + ast.keyword( + arg="max_fixed_arity", + value=ast.Constant(max_fixed_arity), + ), + ast.keyword( + arg="default", + value=ast.Name( + id=default_name, ctx=ast.Load() + ), + ), + ] + if default_name is not None + else [], ) - ], - ) + ), + ), + args=[ + ast.Call( + func=_NEW_SYM_FN_NAME, + args=[ast.Constant(name)], + keywords=[], + ), + meta_py_ast + if meta_py_ast is not None + else ast.Constant(None), + ], + keywords=[], ), - returns=ret_ann_ast, + annotation=ast.Subscript( + value=ast.Name(id="Callable", ctx=ast.Load()), + slice=ast.Tuple( + elts=[ast.Ellipsis(), ret_ann_ast], ctx=ast.Load() + ), + ctx=ast.Load(), + ) + if ret_ann_ast is not None + else None, ) ], ), @@ -3866,6 +3835,7 @@ def _from_module_imports() -> Iterable[ast.ImportFrom]: ast.ImportFrom( module="typing", names=[ + ast.alias(name="Callable", asname=None), ast.alias(name="Union", asname=None), ], level=0, diff --git a/src/basilisp/lang/runtime.py b/src/basilisp/lang/runtime.py index 2f9b32e9..27a8be1d 100644 --- a/src/basilisp/lang/runtime.py +++ b/src/basilisp/lang/runtime.py @@ -1777,9 +1777,8 @@ def wrapped_f(*args, **kwargs): def _basilisp_fn( - arities: Tuple[Union[int, kw.Keyword]], + arities: Tuple[Union[int, kw.Keyword], ...], arity_map: Optional[Mapping[Union[int, kw.Keyword], Callable]] = None, - bind_self: bool = False, ): """Create a Basilisp function, setting meta and supplying a with_meta method implementation.""" @@ -1790,14 +1789,64 @@ def wrap_fn(f): f.arities = lset.set(arities) f.arity_map = lmap.map(arity_map) if arity_map is not None else None f.meta = None - if bind_self: - f = partial(f, f) f.with_meta = partial(_fn_with_meta, f) return f return wrap_fn +def _basilisp_multi_arity_fn( + arities: Tuple[Union[int, kw.Keyword], ...], + arity_map: Mapping[Union[int, kw.Keyword], Callable], + max_fixed_arity: Optional[int] = None, + default: Optional[Callable] = None, +) -> Type: + class BasilispMultiArityFn: + __slots__ = ("_name", "arities", "arity_map", "meta") + + def __init__( + self, + name: Optional[sym.Symbol], + meta: Optional[lmap.PersistentMap] = None, + ): + self._name = name + self.arities = lset.set(arities) + self.arity_map = lmap.map(arity_map) + self.meta = meta + + _basilisp_fn = True + + if max_fixed_arity is not None: + assert default is not None + + def __call__(self, *args): + nargs = len(args) + arity = self.arity_map.get(nargs) + if arity is not None: + return arity(*args) + if nargs > max_fixed_arity: + return default(*args) + raise RuntimeException( + f"Wrong number of args passed to function: {self._name}", nargs + ) + + else: + + def __call__(self, *args): + nargs = len(args) + arity = self.arity_map.get(nargs) + if arity is not None: + return arity(*args) + raise RuntimeException( + f"Wrong number of args passed to function: {self._name}", nargs + ) + + def with_meta(self, meta: Optional[lmap.PersistentMap] = None): + return type(self)(None, meta) + + return BasilispMultiArityFn + + def _basilisp_type( fields: Iterable[str], interfaces: Iterable[Type],