Skip to content

Commit

Permalink
Try this
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisrink10 committed Jan 7, 2024
1 parent e2f706a commit 26ab15a
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 127 deletions.
216 changes: 93 additions & 123 deletions src/basilisp/lang/compiler/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,9 @@ def _var_ns_as_python_sym(name: str) -> str:
_COLLECT_ARGS_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._collect_args")
_COERCE_SEQ_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}.to_seq")
_BASILISP_FN_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._basilisp_fn")
_BASILISP_MULTI_ARITY_FN_FN_NAME = _load_attr(
f"{_RUNTIME_ALIAS}._basilisp_multi_arity_fn"
)
_FN_WITH_ATTRS_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._with_attrs")
_BASILISP_TYPE_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._basilisp_type")
_BASILISP_WITH_META_INTERFACE_NAME = _load_attr(f"{_INTERFACES_ALIAS}.IWithMeta")
Expand Down Expand Up @@ -1814,103 +1817,6 @@ def fn(*args):
dispatch_keys.append(ast.Constant(k))
dispatch_vals.append(ast.Name(id=v, ctx=ast.Load()))

# Async functions should return await, otherwise just return
handle_return = __handle_async_return if is_async else __handle_return

nargs_name = genname("nargs")
arity_name = genname("arity")
body = [
ast.Assign(
targets=[ast.Name(id=nargs_name, ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="len", ctx=ast.Load()),
args=[ast.Name(id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load())],
keywords=[],
),
),
ast.Assign(
targets=[ast.Name(id=arity_name, ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="self", ctx=ast.Load()),
attr="arity_map",
ctx=ast.Load(),
),
attr="get",
ctx=ast.Load(),
),
args=[ast.Name(id=nargs_name, ctx=ast.Load())],
keywords=[],
),
),
ast.If(
test=ast.Compare(
left=ast.Constant(None),
ops=[ast.IsNot()],
comparators=[ast.Name(id=arity_name, ctx=ast.Load())],
),
body=[
handle_return(
ast.Call(
func=ast.Name(id=arity_name, ctx=ast.Load()),
args=[
ast.Starred(
value=ast.Name(
id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load()
),
ctx=ast.Load(),
)
],
keywords=[],
)
)
],
orelse=[]
if default_name is None
else [
ast.If(
test=ast.Compare(
left=ast.Name(id=nargs_name, ctx=ast.Load()),
ops=[ast.GtE()],
comparators=[ast.Constant(max_fixed_arity)],
),
body=[
handle_return(
ast.Call(
func=ast.Name(id=default_name, ctx=ast.Load()),
args=[
ast.Starred(
value=ast.Name(
id=_MULTI_ARITY_ARG_NAME, ctx=ast.Load()
),
ctx=ast.Load(),
)
],
keywords=[],
)
)
],
orelse=[],
)
],
),
ast.Raise(
exc=ast.Call(
func=_load_attr("basilisp.lang.runtime.RuntimeException"),
args=[
ast.Constant(f"Wrong number of args passed to function: {name}"),
ast.Name(id=nargs_name, ctx=ast.Load()),
],
keywords=[],
),
cause=None,
),
]

py_fn_node = ast.AsyncFunctionDef if is_async else ast.FunctionDef
meta_deps, meta_decorators = __fn_meta(ctx, meta_node)

ret_ann_ast: Optional[ast.AST] = None
ret_ann_deps: List[ast.AST] = []
if all(tag is not None for tag in return_tags):
Expand All @@ -1929,40 +1835,103 @@ def fn(*args):
else None
)

meta_py_ast: Optional[ast.AST] = None
meta_deps: List[ast.AST] = []
if meta_node is not None:
meta_ast = gen_py_ast(ctx, meta_node)
meta_py_ast = meta_ast.node
meta_deps.extend(meta_ast.dependencies)

return GeneratedPyAST(
node=ast.Name(id=name, ctx=ast.Load()),
dependencies=chain(
meta_deps,
ret_ann_deps,
[
py_fn_node(
name=name,
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg="self", annotation=None)],
kwarg=None,
vararg=ast.arg(arg=_MULTI_ARITY_ARG_NAME, annotation=None),
kwonlyargs=[],
defaults=[],
kw_defaults=[],
),
body=body,
decorator_list=list(
chain(
meta_decorators,
[
__fn_decorator(
arity_map.keys(),
has_rest_arg=default_name is not None,
arity_map=ast.Dict(
keys=dispatch_keys, values=dispatch_vals
),
bind_self_arg=True,
_tagged_assign(
target=ast.Name(id=name, ctx=ast.Store()),
value=ast.Call(
func=ast.Call(
func=_BASILISP_MULTI_ARITY_FN_FN_NAME,
args=[],
keywords=list(
chain(
[
ast.keyword(
arg="arities",
value=ast.Tuple(
elts=list(
chain(
map(
ast.Constant,
arity_map.keys(),
),
[
ast.Call(
func=_NEW_KW_FN_NAME,
args=[
ast.Constant(
hash(_REST_KW)
),
ast.Constant(
"rest"
),
],
keywords=[],
)
]
if default_name is not None
else [],
)
),
ctx=ast.Load(),
),
),
ast.keyword(
arg="arity_map",
value=ast.Dict(
keys=dispatch_keys, values=dispatch_vals
),
),
],
[
ast.keyword(
arg="max_fixed_arity",
value=ast.Constant(max_fixed_arity),
),
ast.keyword(
arg="default",
value=ast.Name(
id=default_name, ctx=ast.Load()
),
),
]
if default_name is not None
else [],
)
],
)
),
),
args=[
ast.Call(
func=_NEW_SYM_FN_NAME,
args=[ast.Constant(name)],
keywords=[],
),
meta_py_ast
if meta_py_ast is not None
else ast.Constant(None),
],
keywords=[],
),
returns=ret_ann_ast,
annotation=ast.Subscript(
value=ast.Name(id="Callable", ctx=ast.Load()),
slice=ast.Tuple(
elts=[ast.Ellipsis(), ret_ann_ast], ctx=ast.Load()
),
ctx=ast.Load(),
)
if ret_ann_ast is not None
else None,
)
],
),
Expand Down Expand Up @@ -3816,6 +3785,7 @@ def _from_module_imports() -> Iterable[ast.ImportFrom]:
ast.ImportFrom(
module="typing",
names=[
ast.alias(name="Callable", asname=None),
ast.alias(name="Union", asname=None),
],
level=0,
Expand Down
57 changes: 53 additions & 4 deletions src/basilisp/lang/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,9 +1777,8 @@ def wrapped_f(*args, **kwargs):


def _basilisp_fn(
arities: Tuple[Union[int, kw.Keyword]],
arities: Tuple[Union[int, kw.Keyword], ...],
arity_map: Optional[Mapping[Union[int, kw.Keyword], Callable]] = None,
bind_self: bool = False,
):
"""Create a Basilisp function, setting meta and supplying a with_meta
method implementation."""
Expand All @@ -1790,14 +1789,64 @@ def wrap_fn(f):
f.arities = lset.set(arities)
f.arity_map = lmap.map(arity_map) if arity_map is not None else None
f.meta = None
if bind_self:
f = partial(f, f)
f.with_meta = partial(_fn_with_meta, f)
return f

return wrap_fn


def _basilisp_multi_arity_fn(
arities: Tuple[Union[int, kw.Keyword], ...],
arity_map: Mapping[Union[int, kw.Keyword], Callable],
max_fixed_arity: Optional[int] = None,
default: Optional[Callable] = None,
) -> Type:
class BasilispMultiArityFn:
__slots__ = ("_name", "arities", "arity_map", "meta")

def __init__(
self,
name: Optional[sym.Symbol],
meta: Optional[lmap.PersistentMap] = None,
):
self._name = name
self.arities = lset.set(arities)
self.arity_map = lmap.map(arity_map)
self.meta = meta

_basilisp_fn = True

if max_fixed_arity is not None:
assert default is not None

def __call__(self, *args):
nargs = len(args)
arity = self.arity_map.get(nargs)
if arity is not None:
return arity(*args)
if nargs > max_fixed_arity:
return default(*args)
raise RuntimeException(
f"Wrong number of args passed to function: {self._name}", nargs
)

else:

def __call__(self, *args):
nargs = len(args)
arity = self.arity_map.get(nargs)
if arity is not None:
return arity(*args)
raise RuntimeException(
f"Wrong number of args passed to function: {self._name}", nargs
)

def with_meta(self, meta: Optional[lmap.PersistentMap] = None):
return type(self)(None, meta)

return BasilispMultiArityFn


def _basilisp_type(
fields: Iterable[str],
interfaces: Iterable[Type],
Expand Down

0 comments on commit 26ab15a

Please sign in to comment.