Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hide the arity map in generated Python #766

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 128 additions & 146 deletions src/basilisp/lang/compiler/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ def _var_ns_as_python_sym(name: str) -> str:
_COLLECT_ARGS_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._collect_args")
_COERCE_SEQ_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}.to_seq")
_BASILISP_FN_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._basilisp_fn")
_BASILISP_MULTI_ARITY_FN_FN_NAME = _load_attr(
f"{_RUNTIME_ALIAS}._basilisp_multi_arity_fn"
)
_FN_WITH_ATTRS_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._with_attrs")
_BASILISP_TYPE_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._basilisp_type")
_BASILISP_WITH_META_INTERFACE_NAME = _load_attr(f"{_INTERFACES_ALIAS}.IWithMeta")
Expand Down Expand Up @@ -1648,35 +1651,47 @@ def __fn_args_to_py_ast(
def __fn_decorator(
arities: Iterable[int],
has_rest_arg: bool = False,
arity_map: Optional[ast.Dict] = None,
bind_self_arg: Optional[bool] = None,
) -> ast.Call:
return ast.Call(
func=_BASILISP_FN_FN_NAME,
args=[],
keywords=[
ast.keyword(
arg="arities",
value=ast.Tuple(
elts=list(
chain(
map(ast.Constant, arities),
[
ast.Call(
func=_NEW_KW_FN_NAME,
args=[
ast.Constant(hash(_REST_KW)),
ast.Constant("rest"),
],
keywords=[],
keywords=list(
chain(
[
ast.keyword(
arg="arities",
value=ast.Tuple(
elts=list(
chain(
map(ast.Constant, arities),
[
ast.Call(
func=_NEW_KW_FN_NAME,
args=[
ast.Constant(hash(_REST_KW)),
ast.Constant("rest"),
],
keywords=[],
)
]
if has_rest_arg
else [],
)
]
if has_rest_arg
else [],
)
),
ctx=ast.Load(),
),
),
ctx=ast.Load(),
),
)
],
[ast.keyword(arg="arity_map", value=arity_map)]
if arity_map is not None
else [],
[ast.keyword(arg="bind_self", value=ast.Constant(bind_self_arg))]
if bind_self_arg
else [],
)
],
),
)


Expand Down Expand Up @@ -1827,106 +1842,11 @@ def fn(*args):
return default(*args)
raise RuntimeError
"""
dispatch_map_name = f"{name}_dispatch_map"

dispatch_keys, dispatch_vals = [], []
for k, v in arity_map.items():
dispatch_keys.append(ast.Constant(k))
dispatch_vals.append(ast.Name(id=v, ctx=ast.Load()))

# Async functions should return await, otherwise just return
handle_return = __handle_async_return if is_async else __handle_return

nargs_name = genname("nargs")
arity_name = genname("arity")
body = [
ast.Assign(
targets=[ast.Name(id=nargs_name, ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="len", ctx=ast.Load()),
args=[ast.Name(id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load())],
keywords=[],
),
),
ast.Assign(
targets=[ast.Name(id=arity_name, ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=dispatch_map_name, ctx=ast.Load()),
attr="get",
ctx=ast.Load(),
),
args=[ast.Name(id=nargs_name, ctx=ast.Load())],
keywords=[],
),
),
ast.If(
test=ast.Compare(
left=ast.Constant(None),
ops=[ast.IsNot()],
comparators=[ast.Name(id=arity_name, ctx=ast.Load())],
),
body=[
handle_return(
ast.Call(
func=ast.Name(id=arity_name, ctx=ast.Load()),
args=[
ast.Starred(
value=ast.Name(
id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load()
),
ctx=ast.Load(),
)
],
keywords=[],
)
)
],
orelse=[]
if default_name is None
else [
ast.If(
test=ast.Compare(
left=ast.Name(id=nargs_name, ctx=ast.Load()),
ops=[ast.GtE()],
comparators=[ast.Constant(max_fixed_arity)],
),
body=[
handle_return(
ast.Call(
func=ast.Name(id=default_name, ctx=ast.Load()),
args=[
ast.Starred(
value=ast.Name(
id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load()
),
ctx=ast.Load(),
)
],
keywords=[],
)
)
],
orelse=[],
)
],
),
ast.Raise(
exc=ast.Call(
func=_load_attr("basilisp.lang.runtime.RuntimeException"),
args=[
ast.Constant(f"Wrong number of args passed to function: {name}"),
ast.Name(id=nargs_name, ctx=ast.Load()),
],
keywords=[],
),
cause=None,
),
]

py_fn_node = ast.AsyncFunctionDef if is_async else ast.FunctionDef
meta_deps, meta_decorators = __fn_meta(ctx, meta_node)

ret_ann_ast: Optional[ast.AST] = None
ret_ann_deps: List[ast.AST] = []
if all(tag is not None for tag in return_tags):
Expand All @@ -1945,42 +1865,103 @@ def fn(*args):
else None
)

meta_py_ast: Optional[ast.AST] = None
meta_deps: List[ast.AST] = []
if meta_node is not None:
meta_ast = gen_py_ast(ctx, meta_node)
meta_py_ast = meta_ast.node
meta_deps.extend(meta_ast.dependencies)

return GeneratedPyAST(
node=ast.Name(id=name, ctx=ast.Load()),
dependencies=chain(
[
ast.Assign(
targets=[ast.Name(id=dispatch_map_name, ctx=ast.Store())],
value=ast.Dict(keys=dispatch_keys, values=dispatch_vals),
)
],
meta_deps,
ret_ann_deps,
[
py_fn_node(
name=name,
args=ast.arguments(
posonlyargs=[],
args=[],
kwarg=None,
vararg=ast.arg(arg=_MULTI_ARITY_ARG_NAME, annotation=None),
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=body,
decorator_list=list(
chain(
meta_decorators,
[
__fn_decorator(
arity_map.keys(),
has_rest_arg=default_name is not None,
_tagged_assign(
target=ast.Name(id=name, ctx=ast.Store()),
value=ast.Call(
func=ast.Call(
func=_BASILISP_MULTI_ARITY_FN_FN_NAME,
args=[],
keywords=list(
chain(
[
ast.keyword(
arg="arities",
value=ast.Tuple(
elts=list(
chain(
map(
ast.Constant,
arity_map.keys(),
),
[
ast.Call(
func=_NEW_KW_FN_NAME,
args=[
ast.Constant(
hash(_REST_KW)
),
ast.Constant(
"rest"
),
],
keywords=[],
)
]
if default_name is not None
else [],
)
),
ctx=ast.Load(),
),
),
ast.keyword(
arg="arity_map",
value=ast.Dict(
keys=dispatch_keys, values=dispatch_vals
),
),
],
[
ast.keyword(
arg="max_fixed_arity",
value=ast.Constant(max_fixed_arity),
),
ast.keyword(
arg="default",
value=ast.Name(
id=default_name, ctx=ast.Load()
),
),
]
if default_name is not None
else [],
)
],
)
),
),
args=[
ast.Call(
func=_NEW_SYM_FN_NAME,
args=[ast.Constant(name)],
keywords=[],
),
meta_py_ast
if meta_py_ast is not None
else ast.Constant(None),
],
keywords=[],
),
returns=ret_ann_ast,
annotation=ast.Subscript(
value=ast.Name(id="Callable", ctx=ast.Load()),
slice=ast.Tuple(
elts=[ast.Ellipsis(), ret_ann_ast], ctx=ast.Load()
),
ctx=ast.Load(),
)
if ret_ann_ast is not None
else None,
)
],
),
Expand Down Expand Up @@ -3854,6 +3835,7 @@ def _from_module_imports() -> Iterable[ast.ImportFrom]:
ast.ImportFrom(
module="typing",
names=[
ast.alias(name="Callable", asname=None),
ast.alias(name="Union", asname=None),
],
level=0,
Expand Down
Loading
Loading