diff --git a/.github/workflows/bench.yaml b/.github/workflows/bench.yaml index 32028daef..f986d4099 100644 --- a/.github/workflows/bench.yaml +++ b/.github/workflows/bench.yaml @@ -28,8 +28,8 @@ jobs: - name: Install system dependencies run: | - pip install numpy ${{ matrix.install_deps }} + pip install numpy jax echo "${{ matrix.path_extension }}" >> $GITHUB_PATH - name: Cache diff --git a/benchmarks/continuous.py b/benchmarks/continuous.py index b4ff9d8dc..53a2b0933 100644 --- a/benchmarks/continuous.py +++ b/benchmarks/continuous.py @@ -24,6 +24,7 @@ os.environ["OPENBLAS_NUM_THREADS"] = "1" import numpy as np import csv +import jax BASELINE = '8dd1aa8539060a511d0f85779ae2c8019162f567' @@ -95,6 +96,31 @@ def bench(self, _code, _xdg_home): return [Result(self.name, 'runtime', avg_time)] +@dataclass +class PythonSubprocess: + name: str + repeats: int + variant: str = 'latest' + baseline_commit: str = BASELINE # Unused but demanded by the driver + + def clean(self, code, xdg_home): + run(code / 'dex', 'clean', env={'XDG_CACHE_HOME': Path(xdg_home) / self.variant}) + + def bench(self, code, xdg_home): + dex_py_path = code / 'python' + source = code / 'benchmarks' / (self.name + '.py') + env = { # Use the ambient Python so that virtualenv works + 'PATH': os.getenv('PATH'), + # but look for the `dex` package in the installation directory first + 'PYTHONPATH': dex_py_path, + } + runtime = parse_result_runtime(read('python3', source, env=env)) + return [Result(self.name, 'runtime', runtime)] + + def baseline(self): + return Python(self.name, self.repeats, RUNTIME_BASELINES[self.name]) + + def numpy_sum(): n = 1000000 xs = np.arange(n, dtype=np.float64) @@ -122,6 +148,18 @@ def numpy_poly(n): return lambda: np.polynomial.Polynomial([0.0, 1.0, 2.0, 3.0, 4.0])(xs) +def diag_conv_jax(): + # TODO Deduplicate these parameters vs the Dex implementation? + shp = (100, 3, 32, 32) + filter_size = 3 + lhs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32) + rhs = jax.lax.broadcast(jax.numpy.eye(filter_size), (100, 3)) + return lambda: jax.lax.conv_general_dilated( + lhs, rhs, window_strides=(1, 1), padding='SAME', + dimension_numbers=('NCHW', 'OIHW', 'NCHW'), + feature_group_count=1) + + BENCHMARKS = [ DexEndToEnd('kernelregression', 10), DexEndToEnd('psd', 10), @@ -137,6 +175,7 @@ def numpy_poly(n): DexRuntime('poly', 5), DexRuntime('vjp_matmul', 5), DexRuntimeVsDex('conv', 10, baseline_commit='531832c0e18a64c1cab10fc16270b930eed5ed2b'), + PythonSubprocess('conv_py', 5), ] RUNTIME_BASELINES = { 'fused_sum': numpy_sum, @@ -148,12 +187,13 @@ def numpy_poly(n): 'matvec_small': numpy_matvec(10, 10000), 'poly': numpy_poly(100000), 'vjp_matmul': numpy_matmul(500, 1), # TODO: rewrite the baseline in JAX and actually use vjp there + 'conv_py': diag_conv_jax(), } -def run(*args, capture=False, env=None): +def run(*args, capture=False, **kwargs): print('> ' + ' '.join(map(str, args))) - result = subprocess.run(args, text=True, capture_output=capture, env=env) + result = subprocess.run(args, text=True, capture_output=capture, **kwargs) if result.returncode != 0 and capture: line = '-' * 20 print(line, 'CAPTURED STDOUT', line) @@ -177,11 +217,14 @@ def build(commit): if install_path.exists(): print(f'Skipping the build of {commit}') else: - run('git', 'checkout', commit) - run('make', 'install', env=dict(os.environ, PREFIX=commit)) - run('cp', '-r', 'lib', install_path / 'lib') - run('cp', '-r', 'examples', install_path / 'examples') - run('cp', '-r', 'benchmarks', install_path / 'benchmarks') + run('git', 'clone', '.', commit) + run('git', 'checkout', commit, cwd=install_path) + try: + run('make', 'build-ffis-and-exe', cwd=install_path, env=dict(os.environ)) + except subprocess.CalledProcessError: + # Presume that this commit doesn't have the `build-ffis-and-exe` target. + # Then we presumably didn't need the FFI to run anything against it. + run('make', 'install', cwd=install_path, env=dict(os.environ, PREFIX=install_path)) return install_path diff --git a/benchmarks/conv.dx b/benchmarks/conv.dx index f9308840b..51d8ffd6a 100644 --- a/benchmarks/conv.dx +++ b/benchmarks/conv.dx @@ -26,8 +26,13 @@ def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float) def conv {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float) (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float = - for n' c'. - conv_1d kernel.n'.c' (unsafe_i_to_n size) + for n' c'. conv_1d kernel.n'.c' (unsafe_i_to_n size) + +def conv_spec {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float) + (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float = + if size == 3 + then conv kernel 3 + else conv kernel size 'We benchmark it on a roughly representative input. @@ -40,7 +45,9 @@ x1 = for i:(Fin n) m:(Fin width) j:(Fin side) k:(Fin side). :t x1 +filter_size = +3 + %bench "Diagonal convolution" -res = conv x1 3 +res = conv x1 filter_size :t res diff --git a/benchmarks/conv_py.py b/benchmarks/conv_py.py new file mode 100644 index 000000000..fa1682b98 --- /dev/null +++ b/benchmarks/conv_py.py @@ -0,0 +1,36 @@ +import jax +import dex +from dex.interop import jax as djax +import numpy as np + +import time +import timeit + +def bench_python(f, loops=None): + """Return average runtime of `f` in seconds and number of iterations used.""" + if loops is None: + f() + s = time.perf_counter() + f() + e = time.perf_counter() + duration = e - s + loops = max(4, int(2 / duration)) # aim for 2s + return (timeit.timeit(f, number=loops, globals=globals()) / loops, loops) + + +def main(): + with open('benchmarks/conv.dx', 'r') as f: + m = dex.Module(f.read()) + dex_conv = djax.primitive(m.conv_spec) + shp = (int(m.n), int(m.width), int(m.side), int(m.side)) + xs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32) + filter_size = int(m.filter_size) + msg = ("TODO Make dex.interop.primitive return Jax Device Arrays, " + "and change this assert to a block_until_ready() call.") + assert isinstance(dex_conv(xs, filter_size), np.ndarray), msg + time_s, loops = bench_python(lambda : dex_conv(xs, filter_size)) + print(f"> Run time: {time_s} s \t(based on {loops} runs)") + + +if __name__ == '__main__': + main() diff --git a/makefile b/makefile index c234789f1..103a20787 100644 --- a/makefile +++ b/makefile @@ -175,6 +175,18 @@ build-ffis: dexrt-llvm cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/ cp $(STACK_INSTALL_DIR)/lib/libDex.so julia/deps/ +# This target is for CI, because it wants to be able to both run the +# `dex` executable and load the `dex` Python package from the same +# directory, without a needless recompile. +build-ffis-and-exe: dexrt-llvm + $(STACK) build $(STACK_FLAGS) --work-dir .stack-work-ffis \ + --flag dex:foreign --flag dex:optimized --force-dirty + $(STACK) install $(STACK_FLAGS) --work-dir .stack-work-ffis \ + --flag dex:foreign --flag dex:optimized --local-bin-path . + $(eval STACK_INSTALL_DIR=$(shell $(STACK) path --work-dir .stack-work-ffis --local-install-root)) + cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/ + cp $(STACK_INSTALL_DIR)/lib/libDex.so julia/deps/ + build-ci: dexrt-llvm $(STACK) build $(STACK_FLAGS) --force-dirty --ghc-options "-Werror -fforce-recomp" $(dex) clean # clear cache