From 040fe0c893266596e0a0c31be990bb19c6d4c7f7 Mon Sep 17 00:00:00 2001
From: Alexey Radul
Date: Thu, 29 Sep 2022 15:18:58 -0400
Subject: [PATCH] Add a benchmark of calling the Dex diagonal convolution from Jax.

Add a baseline for it of doing the same thing in Jax (as best I can
navigate the convolution docs).

Also adjust the continuous.py script to build and install Dex with the
FFI library by default, so that both the executable and the Python
bindings are accessible.
---
 .github/workflows/bench.yaml |  2 +-
 benchmarks/continuous.py     | 59 +++++++++++++++++++++++++++++++-----
 benchmarks/conv.dx           | 13 ++++++--
 benchmarks/conv_py.py        | 41 +++++++++++++++++++++++++
 makefile                     | 12 ++++++++
 5 files changed, 116 insertions(+), 11 deletions(-)
 create mode 100644 benchmarks/conv_py.py

diff --git a/.github/workflows/bench.yaml b/.github/workflows/bench.yaml
index 32028daef..f986d4099 100644
--- a/.github/workflows/bench.yaml
+++ b/.github/workflows/bench.yaml
@@ -28,8 +28,8 @@ jobs:
 
       - name: Install system dependencies
         run: |
-          pip install numpy
           ${{ matrix.install_deps }}
+          pip install numpy jax
           echo "${{ matrix.path_extension }}" >> $GITHUB_PATH
 
       - name: Cache
diff --git a/benchmarks/continuous.py b/benchmarks/continuous.py
index b4ff9d8dc..53a2b0933 100644
--- a/benchmarks/continuous.py
+++ b/benchmarks/continuous.py
@@ -24,6 +24,7 @@ os.environ["OPENBLAS_NUM_THREADS"] = "1"
 
 import numpy as np
 import csv
+import jax
 
 BASELINE = '8dd1aa8539060a511d0f85779ae2c8019162f567'
 
@@ -95,6 +96,32 @@ def bench(self, _code, _xdg_home):
     return [Result(self.name, 'runtime', avg_time)]
 
 
+@dataclass
+class PythonSubprocess:
+  """Benchmark that runs `benchmarks/<name>.py` under `python3` and parses the run time from its output."""
+  name: str
+  repeats: int
+  variant: str = 'latest'
+  baseline_commit: str = BASELINE  # Unused but demanded by the driver
+
+  def clean(self, code, xdg_home):
+    run(code / 'dex', 'clean', env={'XDG_CACHE_HOME': Path(xdg_home) / self.variant})
+
+  def bench(self, code, xdg_home):
+    dex_py_path = code / 'python'
+    source = code / 'benchmarks' / (self.name + '.py')
+    env = {  # Use the ambient Python so that virtualenv works
+        'PATH': os.getenv('PATH'),
+        # but look for the `dex` package in the installation directory first
+        'PYTHONPATH': dex_py_path,
+    }
+    runtime = parse_result_runtime(read('python3', source, env=env))
+    return [Result(self.name, 'runtime', runtime)]
+
+  def baseline(self):
+    return Python(self.name, self.repeats, RUNTIME_BASELINES[self.name])
+
+
 def numpy_sum():
   n = 1000000
   xs = np.arange(n, dtype=np.float64)
@@ -122,6 +149,19 @@ def numpy_poly(n):
   return lambda: np.polynomial.Polynomial([0.0, 1.0, 2.0, 3.0, 4.0])(xs)
 
 
+def diag_conv_jax():
+  """Return a thunk that runs the JAX baseline convolution and waits for its result."""
+  # TODO Deduplicate these parameters vs the Dex implementation?
+  shp = (100, 3, 32, 32)
+  filter_size = 3
+  lhs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32)
+  rhs = jax.lax.broadcast(jax.numpy.eye(filter_size), (100, 3))
+  return lambda: jax.lax.conv_general_dilated(
+      lhs, rhs, window_strides=(1, 1), padding='SAME',
+      dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
+      feature_group_count=1).block_until_ready()
+
+
 BENCHMARKS = [
   DexEndToEnd('kernelregression', 10),
   DexEndToEnd('psd', 10),
@@ -137,6 +177,7 @@ def numpy_poly(n):
   DexRuntime('poly', 5),
   DexRuntime('vjp_matmul', 5),
   DexRuntimeVsDex('conv', 10, baseline_commit='531832c0e18a64c1cab10fc16270b930eed5ed2b'),
+  PythonSubprocess('conv_py', 5),
 ]
 RUNTIME_BASELINES = {
   'fused_sum': numpy_sum,
@@ -148,12 +189,13 @@ def numpy_poly(n):
   'matvec_small': numpy_matvec(10, 10000),
   'poly': numpy_poly(100000),
   'vjp_matmul': numpy_matmul(500, 1),  # TODO: rewrite the baseline in JAX and actually use vjp there
+  'conv_py': diag_conv_jax(),
 }
 
 
-def run(*args, capture=False, env=None):
+def run(*args, capture=False, **kwargs):
   print('> ' + ' '.join(map(str, args)))
-  result = subprocess.run(args, text=True, capture_output=capture, env=env)
+  result = subprocess.run(args, text=True, capture_output=capture, **kwargs)
   if result.returncode != 0 and capture:
     line = '-' * 20
     print(line, 'CAPTURED STDOUT', line)
@@ -177,11 +219,14 @@ def build(commit):
   if install_path.exists():
     print(f'Skipping the build of {commit}')
   else:
-    run('git', 'checkout', commit)
-    run('make', 'install', env=dict(os.environ, PREFIX=commit))
-    run('cp', '-r', 'lib', install_path / 'lib')
-    run('cp', '-r', 'examples', install_path / 'examples')
-    run('cp', '-r', 'benchmarks', install_path / 'benchmarks')
+    run('git', 'clone', '.', commit)
+    run('git', 'checkout', commit, cwd=install_path)
+    try:
+      run('make', 'build-ffis-and-exe', cwd=install_path, env=dict(os.environ))
+    except subprocess.CalledProcessError:
+      # Presume that this commit doesn't have the `build-ffis-and-exe` target.
+      # Then we presumably didn't need the FFI to run anything against it.
+      run('make', 'install', cwd=install_path, env=dict(os.environ, PREFIX=install_path))
   return install_path
 
 
diff --git a/benchmarks/conv.dx b/benchmarks/conv.dx
index f9308840b..51d8ffd6a 100644
--- a/benchmarks/conv.dx
+++ b/benchmarks/conv.dx
@@ -26,8 +26,13 @@ def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)
 
 def conv {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
          (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
-  for n' c'.
-    conv_1d kernel.n'.c' (unsafe_i_to_n size)
+  for n' c'. conv_1d kernel.n'.c' (unsafe_i_to_n size)
+
+def conv_spec {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
+              (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
+  if size == 3
+    then conv kernel 3
+    else conv kernel size
 
 'We benchmark it on a roughly representative input.
 
@@ -40,7 +45,9 @@ x1 = for i:(Fin n) m:(Fin width) j:(Fin side) k:(Fin side).
 
 :t x1
 
+filter_size = +3
+
 %bench "Diagonal convolution"
-res = conv x1 3
+res = conv x1 filter_size
 
 :t res
diff --git a/benchmarks/conv_py.py b/benchmarks/conv_py.py
new file mode 100644
index 000000000..fa1682b98
--- /dev/null
+++ b/benchmarks/conv_py.py
@@ -0,0 +1,41 @@
+import jax
+import dex
+from dex.interop import jax as djax
+import numpy as np
+
+import os
+import time
+import timeit
+
+def bench_python(f, loops=None):
+  """Return average runtime of `f` in seconds and number of iterations used."""
+  if loops is None:
+    f()
+    s = time.perf_counter()
+    f()
+    e = time.perf_counter()
+    duration = e - s
+    loops = max(4, int(2 / duration))  # aim for 2s
+  return (timeit.timeit(f, number=loops, globals=globals()) / loops, loops)
+
+
+def main():
+  """Time `conv_spec` from conv.dx called through the Dex JAX interop and print the result."""
+  # Read conv.dx from next to this script, so the run does not depend on the
+  # caller's working directory.
+  dx_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'conv.dx')
+  with open(dx_path, 'r') as f:
+    m = dex.Module(f.read())
+  dex_conv = djax.primitive(m.conv_spec)
+  shp = (int(m.n), int(m.width), int(m.side), int(m.side))
+  xs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32)
+  filter_size = int(m.filter_size)
+  msg = ("TODO Make dex.interop.primitive return Jax Device Arrays, "
+         "and change this assert to a block_until_ready() call.")
+  assert isinstance(dex_conv(xs, filter_size), np.ndarray), msg
+  time_s, loops = bench_python(lambda : dex_conv(xs, filter_size))
+  print(f"> Run time: {time_s} s \t(based on {loops} runs)")
+
+
+if __name__ == '__main__':
+  main()
diff --git a/makefile b/makefile
index c234789f1..103a20787 100644
--- a/makefile
+++ b/makefile
@@ -175,6 +175,18 @@ build-ffis: dexrt-llvm
 	cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/
 	cp $(STACK_INSTALL_DIR)/lib/libDex.so julia/deps/
 
+# This target is for CI, because it wants to be able to both run the
+# `dex` executable and load the `dex` Python package from the same
+# directory, without a needless recompile.
+build-ffis-and-exe: dexrt-llvm
+	$(STACK) build $(STACK_FLAGS) --work-dir .stack-work-ffis \
+		--flag dex:foreign --flag dex:optimized --force-dirty
+	$(STACK) install $(STACK_FLAGS) --work-dir .stack-work-ffis \
+		--flag dex:foreign --flag dex:optimized --local-bin-path .
+	$(eval STACK_INSTALL_DIR=$(shell $(STACK) path --work-dir .stack-work-ffis --local-install-root))
+	cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/
+	cp $(STACK_INSTALL_DIR)/lib/libDex.so julia/deps/
+
 build-ci: dexrt-llvm
 	$(STACK) build $(STACK_FLAGS) --force-dirty --ghc-options "-Werror -fforce-recomp"
 	$(dex) clean # clear cache