Skip to content

Commit

Permalink
Integrate Triton up to [9f816a7b](https://github.com/openai/triton/co…
Browse files Browse the repository at this point in the history
  • Loading branch information
gflegar authored and The jax_triton Authors committed Jan 30, 2024
1 parent 4af0ecb commit 28ad476
Show file tree
Hide file tree
Showing 26 changed files with 30 additions and 58 deletions.
2 changes: 1 addition & 1 deletion examples/add.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/block_map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/fused_attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/fusion/benchmark_matmul.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/fusion/nn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/matmul.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/pallas/blocksparse_matmul.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/pallas/fused_attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/pallas/layer_norm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/pallas/lstm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/pallas/templating.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion examples/softmax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/fusion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/jaxpr_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/lowering.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/pallas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/pallas/mosaic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/pallas/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/pallas/triton.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
38 changes: 5 additions & 33 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -48,8 +48,7 @@
from triton.runtime import autotuner
import triton._C.libtriton as _triton
from triton._C.libtriton import ir as tl_ir
from triton.common.backend import get_backend
import triton.compiler.backends.cuda as cb
import triton.backends.nvidia.compiler as cb

CAN_USE_TRITON = True
except ModuleNotFoundError:
Expand Down Expand Up @@ -154,33 +153,6 @@ def aval_size_bytes(aval):
return np.dtype(aval.dtype).itemsize * aval.size


def ptx_get_kernel_name(module) -> str:
return cb.get_kernel_name(module, pattern="// .globl")


def get_arch_default_num_warps(device_type):
if device_type in ["cuda", "hip"]:
num_warps = 4
else:
device_backend = get_backend(device_type)
assert device_backend
arch = device_backend.get_architecture_descriptor()
num_warps = arch["num_warps"]
return num_warps


def get_arch_default_num_stages(device_type, capability):
if device_type == "cuda":
num_stages = 3 if capability >= 75 else 2
else:
device_backend = get_backend(device_type)
assert device_backend
arch = device_backend.get_architecture_descriptor()
num_stages = arch["num_stages"]

return num_stages


def compile_ttir_to_ptx_inplace(
ttir,
tl_context: tl_ir.Context,
Expand Down Expand Up @@ -236,7 +208,7 @@ def compile_ttir_to_ptx_inplace(
)
if cuda_options.debug:
print(ptx)
name = ptx_get_kernel_name(ptx)
name = metadata["name"]
cluster_dims = metadata["cluster_dims"]
return ptx, name, shared_mem_bytes, compute_capability, cluster_dims

Expand All @@ -260,14 +232,14 @@ def get_or_create_triton_kernel(
) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
device_type = "cuda"
if num_warps is None:
num_warps = get_arch_default_num_warps(device_type)
num_warps = 4
# TODO(sharadmv): handle multiple devices, right now we assume device 0
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
arch = triton_kernel_call_lib.get_compute_capability(device)
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type, arch)
num_stages = 3

signature = dict(enumerate(arg_dtypes))
# TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion tests/triton_call_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion tests/triton_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The jax_triton Authors.
# Copyright 2024 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 28ad476

Please sign in to comment.