From e05bd7c0b97019c11d9ee24e7811ae2804c3a8ef Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Mon, 13 Feb 2023 18:05:37 +0000
Subject: [PATCH] WIP AD

---
 jax_triton/pallas/lowering.py    | 22 ++++--
 jax_triton/pallas/pallas_call.py | 89 ++++++++++++++++----------
 lib/triton_kernel_call.cc        | 20 +++++--
 tests/pallas_test.py             | 25 +++++++--
 4 files changed, 100 insertions(+), 56 deletions(-)

diff --git a/jax_triton/pallas/lowering.py b/jax_triton/pallas/lowering.py
index d3125e32..22935d52 100644
--- a/jax_triton/pallas/lowering.py
+++ b/jax_triton/pallas/lowering.py
@@ -28,6 +28,7 @@
 from jax._src.lax.control_flow import for_loop
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
+from jax._src import ad_checkpoint
 from jax._src import core as jax_core
 from jax._src import state
 from jax._src.state import primitives as sp
@@ -572,9 +573,10 @@ def _addupdate_lowering_rule(ctx: TritonLoweringRuleContext, ptr, value,
                  else slc for s, slc in zip(avals_in[0].shape, idx))
   idx = primitives.NDIndexer(idx, avals_in[0].shape, int_indexer_shape)
   ptr = _offset_ptr(ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder, is_scalar)
-  old_value = tl.load(ptr, mask=mask, _builder=ctx.builder)
-  tl.store(ptr, old_value.__add__(value, _builder=ctx.builder),
-           mask=mask, _builder=ctx.builder)
+  tl.atomic_add(ptr, value, _builder=ctx.builder)
+  # old_value = tl.load(ptr, mask=mask, _builder=ctx.builder)
+  # tl.store(ptr, old_value.__add__(value, _builder=ctx.builder),
+  #          mask=mask, _builder=ctx.builder)
   return []
 triton_lowering_rules[sp.addupdate_p] = _addupdate_lowering_rule
 
@@ -621,9 +623,21 @@ def _reduce_argmin_lowering(ctx: TritonLoweringRuleContext, a, *, axes,
 triton_lowering_rules[lax.argmin_p] = _reduce_argmin_lowering
 
 def _xla_call_lowering_rule(ctx: TritonLoweringRuleContext, *args, call_jaxpr, **_):
-  return lower_jaxpr_to_triton_ir(ctx.context, call_jaxpr, *args)
+  return lower_jaxpr_to_triton_ir(ctx.context, call_jaxpr, ctx.block_infos, *args)
 triton_lowering_rules[xla.xla_call_p] = _xla_call_lowering_rule
 
+def _closed_call_lowering_rule(ctx: TritonLoweringRuleContext, *args, call_jaxpr, **_):
+  jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
+  if consts:
+    raise NotImplementedError
+  return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)
+triton_lowering_rules[jax_core.closed_call_p] = _closed_call_lowering_rule
+
+def _remat_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr, **_):
+  return lower_jaxpr_to_triton_ir(ctx.context, jaxpr, ctx.block_infos, *args)
+triton_lowering_rules[ad_checkpoint.remat_p] = _remat_lowering_rule
+
+
 def _for_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr, which_linear,
                        nsteps, reverse, unroll):
   current_bb = ctx.builder.get_insert_block()
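Review note (not part of the patch): the three call rules above share one pattern: call-like primitives (xla_call, closed_call, and now remat) have no kernel-level semantics of their own, so each is lowered by inlining its inner jaxpr, threading ctx.block_infos through so Ref arguments keep their block mappings. The user-visible effect is that jax.checkpoint-wrapped code can appear inside a kernel body. A minimal sketch of what this enables (hypothetical, untested usage, assuming the pl.pallas_call API defined later in this patch):

    import jax
    import jax.numpy as jnp
    from jax_triton import pallas as pl

    def kernel(x_ref, o_ref):
      # jax.checkpoint binds remat_p, which the new rule lowers by simply
      # inlining the wrapped computation into the Triton IR.
      o_ref[...] = jax.checkpoint(jnp.tanh)(x_ref[...])

    f = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32))
    y = f(jnp.arange(8, dtype=jnp.float32))

The switch from tl.load/tl.store to tl.atomic_add in the addupdate rule serves the same AD story: transposed kernels accumulate cotangents into output buffers, and the atomic makes those accumulations safe when several programs hit the same location (the old read-modify-write sequence is kept above as a comment).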
diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py
index 6bd93994..64a8acd5 100644
--- a/jax_triton/pallas/pallas_call.py
+++ b/jax_triton/pallas/pallas_call.py
@@ -165,7 +165,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
   jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
   jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
   jvp_which_linear = (*which_linear, *(True,) * len(tangents))
-  jvp_inshapes = (*in_shapes, *in_shapes)
+  _, nonzero_tangent_in_shapes = partition_list(nonzero_tangents, in_shapes)
+  jvp_inshapes = (*in_shapes, *nonzero_tangent_in_shapes)
   jvp_outshapes = (*out_shapes, *out_shapes)
   if input_output_aliases:
     raise NotImplementedError("`input_output_aliases` jvp not supported.")
@@ -181,7 +182,8 @@
   logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
   logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
   in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)])
-  new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
+  _, nonzero_in_bms = partition_list(nonzero_tangents, in_bms)
+  new_bms = tuple((*in_bms, *nonzero_in_bms, *out_bms, *out_bms))
   new_grid_spec = grid_spec.replace(block_mappings=new_bms)
   jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
                                         *logical_tangent_inputs,
@@ -300,12 +302,12 @@ def _pallas_call_partial_eval(
   jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
       pe.partial_eval_jaxpr_custom(
           jaxpr,
-          in_inst=all_unknowns,
+          in_inst=True,
          in_unknowns=all_unknowns,
           ensure_out_unknowns=[], ensure_out_inst=[],
           saveable=_save_everything)
-  # # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
+  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
   # regular valued input/outputs. However, we'd like to bind these jaxprs to a
   # `for`, which expects only `Ref` inputs and no output. We need to convert
   # both of these jaxprs into ones that are compatible with `for`.
@@ -348,13 +350,13 @@
       for a in res_avals
   ]
   res_block_mappings = [
-      BlockMapping((*[None] * len(grid), *a.shape), index_map)
+      BlockMapping((*[pallas_core.mapped] * len(grid), *a.shape), index_map)
       for a, index_map in zip(res_avals, res_index_mappings)
   ]
   known_grid_spec = GridSpec(grid, (*known_in_block_mappings,
                                     *known_out_block_mappings,
                                     *res_block_mappings),
-                             grid_spec.mapped_dims)
+                             mapped_dims)
   unknown_grid_spec = GridSpec(grid, (*res_block_mappings,
                                       *unknown_in_block_mappings,
                                       *unknown_out_block_mappings),
@@ -371,7 +373,7 @@
       input_output_aliases=(),
       which_linear=tuple(known_which_linear),
       **compiler_params)
-  known_outputs, residuals = split_list(known_out_and_res, [len(known_tracers)])
+  known_outputs, residuals = split_list(known_out_and_res, [len(known_out_shapes)])
   residuals = map(trace.new_instantiated_const, residuals)
   unknown_inputs = [*residuals, *unknown_tracers]
   unknown_outputs = [
@@ -382,8 +384,7 @@
   source = source_info_util.current().replace(name_stack=name_stack)
   unknown_params = dict(
       jaxpr=jaxpr_unknown,
-      in_shapes=(*(jax.ShapeDtypeStruct(s.shape, s.dtype) for s in res_avals),
-                 *unknown_in_shapes),
+      in_shapes=(*res_shapes, *unknown_in_shapes),
       out_shapes=tuple(unknown_out_shapes),
       grid_spec=unknown_grid_spec,
       which_linear=(*res_which_linear, *unknown_which_linear),
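Review note (not part of the patch): both JVP-rule changes above hinge on the argument order of partition_list (from jax._src.util), which splits a list by a boolean mask and returns the False entries first. A tiny illustration:

    from jax._src.util import partition_list

    # Entries whose mask is False come back first, True entries second.
    falses, trues = partition_list([True, False, True], ["a", "b", "c"])
    assert falses == ["b"] and trues == ["a", "c"]

So the nonzero-tangent entries are always the second element of the pair. The JVP kernel takes refs in the order (primals, nonzero tangents, primal outputs, tangent outputs), which is why only inputs with nonzero tangents contribute shapes and block mappings here.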
@@ -399,40 +400,6 @@
   return merge_lists(out_unknowns, known_outputs, unknown_outputs)
 pe.custom_partial_eval_rules[pallas_call_p] = _pallas_call_partial_eval
 
-def _transpose_jaxpr(jaxpr: jax_core.Jaxpr, which_linear: Sequence[bool]
-                     ) -> jax_core.Jaxpr:
-  num_inputs = len(which_linear)
-  num_outputs = len(jaxpr.invars) - num_inputs
-  def trans(*args):
-    # First we want to run the computation to read all the residual refs. We can
-    # do that by using partial evaluation with all linear inputs unknown.
-    res_jaxpr, tangent_jaxpr_, *_ = \
-        pe.partial_eval_jaxpr_custom(jaxpr,
-                                     in_unknowns=[*which_linear, *[True] *
-                                                  num_outputs],
-                                     in_inst=[*which_linear, *[True] *
-                                              num_outputs],
-                                     ensure_out_inst=[],
-                                     ensure_out_unknowns=[],
-                                     saveable=_save_everything)
-    res_args = [x for x, lin in zip(args, which_linear) if not lin]
-    res = jax_core.eval_jaxpr(res_jaxpr, (), *res_args)
-
-    # Now that we have residual values, we run the tangent jaxpr. It takes as
-    # input the residuals, and all the refs (at least, the ones
-    # that are used in the body). Luckily, `tangent_jaxpr_` has all known and
-    # unknown inputs!
-    breakpoint()
-    primals_args = [*(r for u, r in zip(used_res, res) if u)]
-    ct_args = [x for x, u in zip(args, used_ct) if u]
-    ad.backward_pass(
-        tangent_jaxpr, (), False, (), (*res, *ct_args), ())
-    breakpoint()
-    return []
-  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
-      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
-  return jaxpr_trans
-
 def _pallas_call_transpose_rule(cts_in, *args,
                                 jaxpr: jax_core.Jaxpr,
                                 name: str,
@@ -645,6 +612,12 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
   name = compilation_result.name
   asm = compilation_result.asm
   shared_mem = compilation_result.shared_mem
+  ref_effects = state.get_ref_state_effects(
+      [v.aval for v in jaxpr.invars], jaxpr.effects)
+  is_accum = [
+      all(isinstance(eff, state.AccumEffect) for eff in ref_effect)
+      for ref_effect in ref_effects
+  ]
   if debug:
     print(jaxpr)
     print(grid_spec)
@@ -664,7 +637,9 @@ def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
   # All arguments are buffers.
   all_args = [None] * (len(in_shapes) + len(out_shapes))
   kernel_call = triton_kernel_call_lib.TritonKernelCall(
-      kernel, grid[0], grid[1], grid[2], all_args
+      kernel, grid[0], grid[1], grid[2], all_args,
+      is_accum,
+      [s.size for s in [*in_shapes, *out_shapes]]
   )
 
   ctx.module_context.add_keepalive(kernel_call)
@@ -735,6 +710,32 @@ def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
     return arg_shape
   return tuple(s for s in block_spec.block_shape if s is not None)
 
+def _pallas_call_bind(*args,
+                      jaxpr: jax_core.Jaxpr,
+                      name: str,
+                      in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                      out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                      which_linear: Tuple[bool, ...],
+                      interpret: bool,
+                      debug: bool,
+                      input_output_aliases: Tuple[Tuple[int, int], ...],
+                      grid_spec: GridSpec,
+                      **compiler_params: Any):
+  num_inputs = len(in_shapes)
+  num_outputs = len(out_shapes)
+  assert len(jaxpr.invars) == num_inputs + num_outputs, (len(jaxpr.invars),
+                                                         num_inputs,
+                                                         num_outputs)
+  assert len(grid_spec.block_mappings) == len(jaxpr.invars)
+  return jax_core.Primitive.bind(
+      pallas_call_p, *args,
+      jaxpr=jaxpr, name=name, in_shapes=in_shapes,
+      out_shapes=out_shapes, which_linear=which_linear,
+      interpret=interpret, debug=debug,
+      input_output_aliases=input_output_aliases,
+      grid_spec=grid_spec, **compiler_params)
+pallas_call_p.def_custom_bind(_pallas_call_bind)
+
 def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
                 grid: Optional[Grid] = None,
                 in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
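Review note (not part of the patch): the lowering above and the launcher change below form a contract: a Ref whose every state effect is an AccumEffect is only ever added into (via the atomic addupdate lowering), never stored to, so its buffer must start at zero. The lowering computes is_accum per buffer and hands the runtime element counts; the launcher zero-fills the flagged buffers before the kernel runs. Note that the cuMemsetD8Async call below clears size * 4 bytes, i.e. it currently assumes 4-byte (float32) elements. A toy model of the invariant (plain NumPy, illustration only):

    import numpy as np

    # The kernel body only ever performs out += contribution (tl.atomic_add);
    # nothing writes out = value, so correctness requires a zero-filled start,
    # which the launcher provides with cuMemsetD8Async.
    out = np.zeros(4, dtype=np.float32)  # stands in for the runtime zero-fill
    for contribution in (np.ones(4, np.float32), 2 * np.ones(4, np.float32)):
      out += contribution               # stands in for kernel-side accumulation
    np.testing.assert_allclose(out, 3.0)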
diff --git a/lib/triton_kernel_call.cc b/lib/triton_kernel_call.cc
index e5483763..f53a4b46 100644
--- a/lib/triton_kernel_call.cc
+++ b/lib/triton_kernel_call.cc
@@ -145,12 +145,23 @@ class TritonKernelCall : public TritonKernelCallBase {
  public:
   TritonKernelCall(TritonKernel& kernel, uint32_t grid_0, uint32_t grid_1,
                    uint32_t grid_2,
-                   std::vector<std::optional<uint64_t>> parameters)
+                   std::vector<std::optional<uint64_t>> parameters,
+                   std::vector<bool> zero_out,
+                   std::vector<uint64_t> sizes)
       : kernel_(kernel),
         grid_{grid_0, grid_1, grid_2},
-        parameters_(std::move(parameters)) {}
+        parameters_(std::move(parameters)),
+        zero_out_(std::move(zero_out)),
+        sizes_(std::move(sizes)) {}
 
   void Launch(CUstream stream, void** buffers) override final {
+    for (size_t i = 0; i < sizes_.size(); ++i) {
+      bool should_zero = zero_out_[i];
+      if (should_zero) {
+        uint64_t size = sizes_[i];
+        CHECK_CUDA(cuMemsetD8Async((CUdeviceptr) (buffers[i]), 0, size * 4, stream));
+      }
+    }
     std::vector<void*> params;
     params.reserve(parameters_.size());
     for (std::optional<uint64_t>& param : parameters_) {
@@ -169,6 +180,8 @@ class TritonKernelCall : public TritonKernelCallBase {
   uint32_t grid_[3];
   // Parameter values. `nullopt` values represent buffer arguments.
   std::vector<std::optional<uint64_t>> parameters_;
+  std::vector<bool> zero_out_;
+  std::vector<uint64_t> sizes_;
 };
 
 class TritonAutotunedKernelCall : public TritonKernelCallBase {
@@ -320,7 +333,8 @@ PYBIND11_MODULE(triton_kernel_call_lib, m) {
   py::class_<TritonKernelCall>(m, "TritonKernelCall")
       .def(py::init<TritonKernel&, uint32_t, uint32_t, uint32_t,
-                    std::vector<std::optional<uint64_t>>>(),
+                    std::vector<std::optional<uint64_t>>, std::vector<bool>,
+                    std::vector<uint64_t>>(),
           py::keep_alive<1, 2>())  // Ensure that the kernel lives long enough.
       .def_property_readonly("descriptor", [](TritonKernelCall& kernel_call) {
         union {
diff --git a/tests/pallas_test.py b/tests/pallas_test.py
index 0ddcd2d4..4eeee37c 100644
--- a/tests/pallas_test.py
+++ b/tests/pallas_test.py
@@ -534,8 +534,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
       ("square", lambda x: x * x),
       ("add_one", lambda x: x + 1.),
       ("exp", jnp.exp),
-      # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
-      # updated
+      ("tanh", jnp.tanh),
   ])
   def test_jvp(self, impl):
     @functools.partial(
@@ -560,8 +559,7 @@ def pallas_impl(x_ref, o_ref):
       ("square", lambda x: x * x),
       ("add_one", lambda x: x + 1.),
       ("exp", jnp.exp),
-      # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
-      # updated
+      ("tanh", jnp.tanh),
   ])
   def test_jvp_slice(self, impl):
     @functools.partial(
@@ -584,7 +582,6 @@ def pallas_impl(x_ref, o_ref):
         rtol=1e-5)
     jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)
 
-  TODO(sharadmv): enable this when we update Triton
   def test_jvp_matmul(self):
     k1, k2 = random.split(random.PRNGKey(0))
     x = random.normal(k1, (256, 128))
@@ -610,6 +607,24 @@ def add_vectors(x_ref, y_ref, o_ref):
     out_ref = xy[0] + xy[1]
     np.testing.assert_allclose(out, out_ref)
 
+  @parameterized.named_parameters(*[
+      ("square", lambda x: x * x),
+      ("add_one", lambda x: x + 1.),
+      ("exp", jnp.exp),
+      ("tanh", jnp.tanh),
+  ])
+  def test_grad(self, impl):
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32))
+    def pallas_impl(x_ref, o_ref):
+      o_ref[...] = impl(x_ref[...])
+
+    x = random.normal(random.PRNGKey(0))
+    g = jax.grad(pallas_impl)(x)
+    g_ref = jax.grad(impl)(x)
+    np.testing.assert_allclose(g, g_ref, atol=1e-5, rtol=1e-5)
+    jtu.check_grads(pallas_impl, (x,), modes=["rev"], order=1)
+
 class PallasCallVmapTest(PallasTest):
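Review note (not part of the patch): test_grad is the end-to-end check for the machinery above: partial evaluation splits the kernel into a known part (primal plus residual outputs) and an unknown part, the transpose runs through the atomic addupdate lowering, and the cotangent buffers rely on the new zero-initialization in the launcher. Outside the test harness the same path looks roughly like this (hypothetical, untested usage; pl.pallas_call as defined in this repo):

    import functools
    import jax
    import jax.numpy as jnp
    from jax_triton import pallas as pl

    @functools.partial(pl.pallas_call,
                       out_shape=jax.ShapeDtypeStruct((), jnp.float32))
    def f(x_ref, o_ref):
      o_ref[...] = jnp.tanh(x_ref[...])

    g = jax.grad(f)(0.5)  # == 1 - jnp.tanh(0.5) ** 2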