Skip to content

Commit

Permalink
Add a benchmark of calling the Dex diagonal convolution from Jax.
Browse files Browse the repository at this point in the history
Add a baseline for it of doing the same thing in Jax (as best I can
navigate the convolution docs).

Also adjust the continuous.py script to build and install Dex with the
FFI library by default, so that both the executable and the Python
bindings are accessible.
  • Loading branch information
axch committed Oct 17, 2022
1 parent d75b152 commit 040fe0c
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/bench.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ jobs:

- name: Install system dependencies
run: |
pip install numpy
${{ matrix.install_deps }}
pip install numpy jax
echo "${{ matrix.path_extension }}" >> $GITHUB_PATH
- name: Cache
Expand Down
57 changes: 50 additions & 7 deletions benchmarks/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
os.environ["OPENBLAS_NUM_THREADS"] = "1"
import numpy as np
import csv
import jax


BASELINE = '8dd1aa8539060a511d0f85779ae2c8019162f567'
Expand Down Expand Up @@ -95,6 +96,31 @@ def bench(self, _code, _xdg_home):
return [Result(self.name, 'runtime', avg_time)]


@dataclass
class PythonSubprocess:
name: str
repeats: int
variant: str = 'latest'
baseline_commit: str = BASELINE # Unused but demanded by the driver

def clean(self, code, xdg_home):
run(code / 'dex', 'clean', env={'XDG_CACHE_HOME': Path(xdg_home) / self.variant})

def bench(self, code, xdg_home):
dex_py_path = code / 'python'
source = code / 'benchmarks' / (self.name + '.py')
env = { # Use the ambient Python so that virtualenv works
'PATH': os.getenv('PATH'),
# but look for the `dex` package in the installation directory first
'PYTHONPATH': dex_py_path,
}
runtime = parse_result_runtime(read('python3', source, env=env))
return [Result(self.name, 'runtime', runtime)]

def baseline(self):
return Python(self.name, self.repeats, RUNTIME_BASELINES[self.name])


def numpy_sum():
n = 1000000
xs = np.arange(n, dtype=np.float64)
Expand Down Expand Up @@ -122,6 +148,18 @@ def numpy_poly(n):
return lambda: np.polynomial.Polynomial([0.0, 1.0, 2.0, 3.0, 4.0])(xs)


def diag_conv_jax():
# TODO Deduplicate these parameters vs the Dex implementation?
shp = (100, 3, 32, 32)
filter_size = 3
lhs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32)
rhs = jax.lax.broadcast(jax.numpy.eye(filter_size), (100, 3))
return lambda: jax.lax.conv_general_dilated(
lhs, rhs, window_strides=(1, 1), padding='SAME',
dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
feature_group_count=1)


BENCHMARKS = [
DexEndToEnd('kernelregression', 10),
DexEndToEnd('psd', 10),
Expand All @@ -137,6 +175,7 @@ def numpy_poly(n):
DexRuntime('poly', 5),
DexRuntime('vjp_matmul', 5),
DexRuntimeVsDex('conv', 10, baseline_commit='531832c0e18a64c1cab10fc16270b930eed5ed2b'),
PythonSubprocess('conv_py', 5),
]
RUNTIME_BASELINES = {
'fused_sum': numpy_sum,
Expand All @@ -148,12 +187,13 @@ def numpy_poly(n):
'matvec_small': numpy_matvec(10, 10000),
'poly': numpy_poly(100000),
'vjp_matmul': numpy_matmul(500, 1), # TODO: rewrite the baseline in JAX and actually use vjp there
'conv_py': diag_conv_jax(),
}


def run(*args, capture=False, env=None):
def run(*args, capture=False, **kwargs):
print('> ' + ' '.join(map(str, args)))
result = subprocess.run(args, text=True, capture_output=capture, env=env)
result = subprocess.run(args, text=True, capture_output=capture, **kwargs)
if result.returncode != 0 and capture:
line = '-' * 20
print(line, 'CAPTURED STDOUT', line)
Expand All @@ -177,11 +217,14 @@ def build(commit):
if install_path.exists():
print(f'Skipping the build of {commit}')
else:
run('git', 'checkout', commit)
run('make', 'install', env=dict(os.environ, PREFIX=commit))
run('cp', '-r', 'lib', install_path / 'lib')
run('cp', '-r', 'examples', install_path / 'examples')
run('cp', '-r', 'benchmarks', install_path / 'benchmarks')
run('git', 'clone', '.', commit)
run('git', 'checkout', commit, cwd=install_path)
try:
run('make', 'build-ffis-and-exe', cwd=install_path, env=dict(os.environ))
except subprocess.CalledProcessError:
# Presume that this commit doesn't have the `build-ffis-and-exe` target.
# Then we presumably didn't need the FFI to run anything against it.
run('make', 'install', cwd=install_path, env=dict(os.environ, PREFIX=install_path))
return install_path


Expand Down
13 changes: 10 additions & 3 deletions benchmarks/conv.dx
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)

def conv {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
(size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
for n' c'.
conv_1d kernel.n'.c' (unsafe_i_to_n size)
for n' c'. conv_1d kernel.n'.c' (unsafe_i_to_n size)

def conv_spec {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
(size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
if size == 3
then conv kernel 3
else conv kernel size

'We benchmark it on a roughly representative input.

Expand All @@ -40,7 +45,9 @@ x1 = for i:(Fin n) m:(Fin width) j:(Fin side) k:(Fin side).

:t x1

filter_size = +3

%bench "Diagonal convolution"
res = conv x1 3
res = conv x1 filter_size

:t res
36 changes: 36 additions & 0 deletions benchmarks/conv_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import jax
import dex
from dex.interop import jax as djax
import numpy as np

import time
import timeit

def bench_python(f, loops=None):
"""Return average runtime of `f` in seconds and number of iterations used."""
if loops is None:
f()
s = time.perf_counter()
f()
e = time.perf_counter()
duration = e - s
loops = max(4, int(2 / duration)) # aim for 2s
return (timeit.timeit(f, number=loops, globals=globals()) / loops, loops)


def main():
with open('benchmarks/conv.dx', 'r') as f:
m = dex.Module(f.read())
dex_conv = djax.primitive(m.conv_spec)
shp = (int(m.n), int(m.width), int(m.side), int(m.side))
xs = jax.random.normal(jax.random.PRNGKey(1), shp, dtype=jax.numpy.float32)
filter_size = int(m.filter_size)
msg = ("TODO Make dex.interop.primitive return Jax Device Arrays, "
"and change this assert to a block_until_ready() call.")
assert isinstance(dex_conv(xs, filter_size), np.ndarray), msg
time_s, loops = bench_python(lambda : dex_conv(xs, filter_size))
print(f"> Run time: {time_s} s \t(based on {loops} runs)")


if __name__ == '__main__':
main()
12 changes: 12 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ build-ffis: dexrt-llvm
cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/
cp $(STACK_INSTALL_DIR)/lib/libDex.so julia/deps/

# This target is for CI, because it wants to be able to both run the
# `dex` executable and load the `dex` Python package from the same
# directory, without a needless recompile.
build-ffis-and-exe: dexrt-llvm
$(STACK) build $(STACK_FLAGS) --work-dir .stack-work-ffis \
--flag dex:foreign --flag dex:optimized --force-dirty
$(STACK) install $(STACK_FLAGS) --work-dir .stack-work-ffis \
--flag dex:foreign --flag dex:optimized --local-bin-path .
$(eval STACK_INSTALL_DIR=$(shell $(STACK) path --work-dir .stack-work-ffis --local-install-root))
cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/
cp $(STACK_INSTALL_DIR)/lib/libDex.so julia/deps/

build-ci: dexrt-llvm
$(STACK) build $(STACK_FLAGS) --force-dirty --ghc-options "-Werror -fforce-recomp"
$(dex) clean # clear cache
Expand Down

0 comments on commit 040fe0c

Please sign in to comment.