Skip to content

Commit

Permalink
PYX generator
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavidberger committed Mar 16, 2022
1 parent 6b35031 commit b8bdeaa
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 19 deletions.
193 changes: 174 additions & 19 deletions cnkalman/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,11 @@ def parse_type(n, type, parent, default):
def get_argument(n, argument_specs, annotation, parent=None, default = None):
if argument_specs is not None and n in argument_specs:
a = argument_specs[n]
digits = math.floor(math.log(a))
if isinstance(a, tuple):
return WrapTuple(n, a)
if isinstance(a, int):
return WrapTuple(n, [ symbols(f"{sanitize_name(n)}{i}") for i in range(a)])
return WrapTuple(n, [ symbols(f"{sanitize_name(n)}{str(i).zfill(digits)}") for i in range(a)])
return a
if annotation is not None:
return parse_type(n, annotation, parent, default)
Expand Down Expand Up @@ -419,6 +420,136 @@ def arg_str(arg):
def generate_args_string(args, as_call = False):
return ", ".join(map(lambda x: get_name(x[1]) if as_call else arg_str, enumerate(args)))

def generate_pyxcode(func, name=None, args=None, suffix = None, argument_specs ={}, outputs = None, preamble = "", file=None, input_keys = None, prefix = ""):
def emit_code(*args, **kwargs):
if file is not None:
print(*args, **kwargs, file=file)

flatten, args = flatten_func(func, name, args, suffix, argument_specs)
if flatten is None:
return None

if outputs is None:
if hasattr(flatten, "shape"):
outputs = [("out", flatten.shape)]
else:
outputs = [("out", -1)]

if callable(func):
name = func.__name__
annotations = inspect.getfullargspec(func).annotations
args = [get_argument(n, argument_specs, annotations.get(n)) for n in inspect.getfullargspec(func).args]

if suffix is not None:
name = name + "_" + suffix

singular_return = len(flatten) == 1

keys = None
free_symbols = set()
def update_free_symbols(v):
if hasattr(v, 'free_symbols'):
free_symbols.update({k.__str__() for k in v.free_symbols})
return

if isinstance(v, Iterable):
for v1 in v:
update_free_symbols(v1)

if isinstance(flatten, dict):
flatten.pop("$original")
keys = list(flatten.keys())
values = [flatten[k] for k in keys]
keys = [k[1] for k in keys]
values = [ a.symengine_type() if hasattr(a, 'symengine_type') else a for a in values ]
cse_output = cse(symengine.Matrix(values))
update_free_symbols(values)
else:
cse_output = cse(symengine.Matrix(flatten))
update_free_symbols(flatten)

if singular_return:
emit_code("def %s%s(%s):" % (prefix, name, ", ".join(map(arg_str, enumerate(args)))))
else:
emit_code("def %s%s(%s):" % (prefix, name, ", ".join(map(lambda x: x.n, args))))

if preamble:
emit_code(preamble.strip("\r\n"))

# Unroll struct types
for idx, a in enumerate(args):
if callable(a):
name = get_name(a)
for k, v in flatten_args(a()):
if f"{name}{k.strip('[]')}" in free_symbols:
emit_code("\tcdef float %s = %s%s" % (str(v), "(*"+name+")" if isinstance_namedtuple(a()) else name, k))
elif isinstance(a, WrapTuple):
name = get_name(a)
digits = math.floor(math.log(len(a.t)))
for k, v in flatten_args(a.t):
idx = k.strip('[]')
if f"{name}{str(idx).zfill(digits)}" in free_symbols:
emit_code("\tcdef float %s = %s%s" % (str(v), name, k))

for item in cse_output[0]:
stripped_line = ccode(item[1]).replace("\n", " ").replace("\t", " ")
emit_code(f"\tcdef float {symengine.ccode(item[0])} = {stripped_line};")

output_idx = 0
outputs_idx = 0

count_zeros = 0
for item_idx, item in enumerate(cse_output[1]):
if item == 0:
count_zeros += 1
needs_set_zero = False#count_zeros > len(cse_output[1]) / 4

if keys is None and not singular_return:
current_shape = outputs[outputs_idx][1] if isinstance(outputs[outputs_idx][1], tuple) else [outputs[outputs_idx][1], 1]
var = outputs[outputs_idx][0]

emit_code(f"\tcdef np.ndarray[float, ndim=2] {outputs[outputs_idx][0]} = np.zeros(({current_shape[0]},{current_shape[1]}), dtype=np.float32)")
for item_idx, item in enumerate(cse_output[1]):
if keys is None:
current_shape = outputs[outputs_idx][1] if isinstance(outputs[outputs_idx][1], tuple) else [outputs[outputs_idx][1], 1]
current_row = output_idx // current_shape[1]
current_col = output_idx % current_shape[1]

def get_col_str():
if len(outputs[outputs_idx]) > 2 and hasattr(outputs[outputs_idx][2][current_col], 'offsetof'):
offset_of = outputs[outputs_idx][2][current_col].offsetof()
if offset_of is not None:
return offset_of
return str(current_col)
def get_row_str():
if input_keys is not None:
root, path = input_keys[current_row]
return f"offsetof({root}, {path})/sizeof(FLT)"
return str(current_row)
if hasattr(item, "tolist"):
for item1 in sum(item.tolist(), []):
emit_code("\t%s[%s,%s] = %s" % (outputs[outputs_idx][0], get_row_str(), get_col_str(), output_idx, ccode(item1).replace("\n", " ").replace("\t", " ")))
output_idx += 1
current_row = output_idx / current_shape[1]
current_col = output_idx % current_shape[1]
else:
if singular_return:
emit_code("\treturn %s;" % (ccode(item).replace("\n", " ").replace("\t", " ")))
else:
if item != 0 or not needs_set_zero:
emit_code("\t%s[%s,%s] = %s" % (outputs[outputs_idx][0], get_row_str(), get_col_str(), ccode(item).replace("\n", " ").replace("\t", " ")))
output_idx += 1
if output_idx >= math.prod(current_shape) > 0:
emit_code(f"\treturn {outputs[outputs_idx][0]}")
outputs_idx += 1
output_idx = 0
else:
nl = "\n"
emit_code(f"\tout->{keys[item_idx]}={ccode(item).replace(nl, '')};")

emit_code("")
return flatten

def generate_ccode(func, name=None, args=None, suffix = None, argument_specs ={}, outputs = None, preamble = "", file=None, input_keys = None, prefix = ""):
def emit_code(*args, **kwargs):
if file is not None:
Expand Down Expand Up @@ -581,7 +712,7 @@ def flat_values(a):
return [a]


def generate_jacobians(func, suffix=None,transpose=False,jac_all=False, jac_over=None, argument_specs={}, file=None, prefix=""):
def generate_jacobians(func, suffix=None,transpose=False,jac_all=False, jac_over=None, argument_specs={}, file=None, prefix="", codegen=generate_ccode):
def emit_code(*args, **kwargs):
if file is not None:
print(*args, **kwargs, file=file)
Expand Down Expand Up @@ -627,12 +758,12 @@ def emit_code(*args, **kwargs):
if jac_size == 1:
continue

emit_code("// Jacobian of", func.__name__, "wrt", jac_value)
generate_ccode(this_jac, fname, func_args, suffix=suffix, outputs=[('Hx', jac_shape, jac_value)], input_keys=keys,file=file, prefix=prefix)
#emit_code("// Jacobian of", func.__name__, "wrt", jac_value)
codegen(this_jac, fname, func_args, suffix=suffix, outputs=[('Hx', jac_shape, jac_value)], input_keys=keys,file=file, prefix=prefix)

#jac_with_hx = this_jac.reshape(jac_size, 1).col_join(fxm.reshape(fx_size, 1))

emit_code("// Full version Jacobian of", func.__name__, "wrt", jac_value)
#emit_code("// Full version Jacobian of", func.__name__, "wrt", jac_value)

fn_suffix = ""
if suffix is not None:
Expand All @@ -656,15 +787,16 @@ def emit_code(*args, **kwargs):
# }}"""
# generate_ccode(jac_with_hx, fname + "_with_hx", func_args, suffix=suffix, outputs=[('Hx', jac_shape, jac_value), ('hx', this_jac.shape[0] - jac_size)], preamble=preamble, file)
outputs = [('Hx', jac_shape, jac_value), ('hx', this_jac.shape[0] - jac_size)]
emit_code(f"""
static inline void {prefix}{fname}_with_hx({", ".join(["CnMat* " + s[0] for s in outputs])}, {", ".join(map(arg_str, enumerate(func_args)))}) {{
if(hx != 0) {{
{gen_call}
}}
if(Hx != 0) {{
{prefix}{fname}{fn_suffix}(Hx, {generate_args_string(func_args, True)});
}}
}}""")
if codegen == generate_ccode:
emit_code(f"""
static inline void {prefix}{fname}_with_hx({", ".join(["CnMat* " + s[0] for s in outputs])}, {", ".join(map(arg_str, enumerate(func_args)))}) {{
if(hx != 0) {{
{gen_call}
}}
if(Hx != 0) {{
{prefix}{fname}{fn_suffix}(Hx, {generate_args_string(func_args, True)});
}}
}}""")

rtn['jacobian_of_' + name] = this_jac.reshape(*jac_shape)
return rtn, func_args
Expand All @@ -678,18 +810,41 @@ def can_generate_jacobian(f):
return all(map(can_generate_jacobian, f))
return True

def generate_code_and_jacobians(f,transpose=False, jac_over=None, argument_specs = {}, file=None, prefix=""):
f_eval = generate_ccode(f, argument_specs = argument_specs, file=file, prefix=prefix)
def generate_code_and_jacobians(f,transpose=False, jac_over=None, argument_specs = {}, file=None, prefix="", codegen=generate_ccode):
f_eval = codegen(f, argument_specs = argument_specs, file=file, prefix=prefix)
if can_generate_jacobian(f_eval):
return generate_jacobians(f, argument_specs = argument_specs, transpose=transpose, jac_over=jac_over, file=file, prefix=prefix)
return generate_jacobians(f, argument_specs = argument_specs, transpose=transpose, jac_over=jac_over, file=file, prefix=prefix, codegen=codegen)
return None, None

from pathlib import Path

generate_code_files = {}
def get_pyx_file(fn):
if not '--cnkalman-generate-source-pyx' in sys.argv:
return None
if fn != sys.argv[0]:
return None
if fn in generate_code_files:
return generate_code_files[fn]
path = Path(fn)
print(f"Generating {path.parent.as_posix()}/{path.stem}_gen.pyx...", file=sys.stderr)
f = generate_code_files[fn] = open(f"{path.parent.as_posix()}/{path.stem}_gen.pyx", 'w')
f.write(
"""# NOTE: This is a generated file; do not edit.
# clang-format off
import cython
import numpy as np
cimport numpy as np
from libc.math cimport *
from libc.stdint cimport *
from libc cimport *
""")
return generate_code_files[fn]

def get_file(fn):
if not '--cnkalman-generate-source' in sys.argv:
return None
return get_pyx_file(fn)
if fn != sys.argv[0]:
return None
if fn in generate_code_files:
Expand Down Expand Up @@ -745,7 +900,7 @@ def g(*args):
return np.array(grtn, dtype=np.float64)
return grtn

jacs, args = generate_code_and_jacobians(func, argument_specs=kwargs, file=f, prefix=prefix)
jacs, args = generate_code_and_jacobians(func, argument_specs=kwargs, file=f, prefix=prefix, codegen= generate_pyxcode if f is not None and f.name.endswith(".pyx") else generate_ccode)
if jacs is not None:
for k, v in jacs.items():
setattr(g, k, functionify(args, v))
Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[dependencies]
opencv-contrib-python = "^4.5.1.48"
scipy="^1.7.0"
numpy="^1.19.5"
opencv-python="^4.5.5.62"

[build-system]
requires = [
"setuptools>=42",
"wheel",
"setuptools-git-versioning",
]
build-backend = "setuptools.build_meta"

[tool.setuptools-git-versioning]
enabled = true
23 changes: 23 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[metadata]
name = cnkalman
version = 0.0.1
author = Justin Berger
author_email = [email protected]
description = Support tools for cnkalman
long_description = file: README.md
long_description_content_type = text/markdown
classifiers =
Programming Language :: Python :: 3
Operating System :: OS Independent

[options]
package_dir =
= src
packages = find:
python_requires = >=3.6
install_requires =
sympy
symengine

[options.packages.find]
where = .

0 comments on commit b8bdeaa

Please sign in to comment.