Make Flux Transformer RoPE use a custom IREE kernel
We assume that the custom kernel will yield better performance than the equivalent PyTorch ops.
sogartar committed Jan 28, 2025
1 parent 3cecd77 commit 4ff5c3e
Showing 3 changed files with 35 additions and 24 deletions.
16 changes: 9 additions & 7 deletions sharktank/sharktank/layers/mmdit.py
@@ -13,6 +13,7 @@
 from torch import Tensor
 
 from .. import ops
+from .. import kernels
 
 from .base import Theta, ThetaLayer
 from .linear import LinearLayer
@@ -25,17 +26,18 @@ def qk_norm(q, k, v, rms_q, rms_k):
     return rms_q(q).to(v), rms_k(k).to(v)
 
 
+# TODO: Work on unifying with the current RoPE layer
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+    xq2 = xq.permute(0, 2, 1, 3).to(freqs_cis.dtype)
+    xk2 = xk.permute(0, 2, 1, 3).to(freqs_cis.dtype)
+    xq_out = kernels.apply_rotary_embedding(xq2.to(freqs_cis.dtype), freqs_cis)
+    xk_out = kernels.apply_rotary_embedding(xk2.to(freqs_cis.dtype), freqs_cis)
+    xq_out = xq_out.permute(0, 2, 1, 3)
+    xk_out = xk_out.permute(0, 2, 1, 3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
 def attention(q, k, v, pe):
-    q, k = apply_rope(q, k, pe)  # todo
+    q, k = apply_rope(q, k, pe)
 
     x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None)
     x = ops.permute(x, (0, 2, 1, 3))
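
For orientation, here is a shape-only sketch of the data flow through the new apply_rope. The kernel call is mocked as an identity (the real kernels.apply_rotary_embedding performs the rotation inside IREE), and the (batch, heads, seq, head_dim) layout of xq and the (batch, seq, head_dim) layout of freqs_cis are assumptions for illustration, not taken from this commit.

import torch

def mock_apply_rotary_embedding(x, table):
    # Identity stand-in for kernels.apply_rotary_embedding; only shapes and dtypes matter here.
    return x

bs, heads, seq, head_dim = 2, 24, 16, 128                        # hypothetical sizes
xq = torch.randn(bs, heads, seq, head_dim, dtype=torch.bfloat16)
freqs_cis = torch.zeros(bs, seq, head_dim, dtype=torch.float32)  # assumed table layout

xq2 = xq.permute(0, 2, 1, 3).to(freqs_cis.dtype)   # (bs, seq, heads, head_dim), kernel dtype
xq_out = mock_apply_rotary_embedding(xq2, freqs_cis)
xq_out = xq_out.permute(0, 2, 1, 3).type_as(xq)    # back to (bs, heads, seq, head_dim), bf16
print(xq_out.shape, xq_out.dtype)                  # torch.Size([2, 24, 16, 128]) torch.bfloat16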
21 changes: 16 additions & 5 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -220,12 +220,9 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
         return xt_out.type_as(xt)
 
     def _compute_rotary_embed_table(self, t):
-        dim = self.rope_dimension_count
-        freqs = 1.0 / (
-            self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
+        return compute_rotary_embedding_table(
+            t, self.rope_dimension_count, self.rope_freq_base, torch.float32
         )
-        freqs = torch.outer(t, freqs).float()
-        return freqs
 
     def _create_rotary_embed_table(self):
         t = torch.arange(self.max_seqlen, device=self.device)
@@ -238,3 +235,17 @@ def _replicate(self, t):
             t = ops.replicate(t, self.tensor_parallelism_size)
 
         return t
+
+
+def compute_rotary_embedding_table(
+    positions: torch.Tensor,
+    rope_dimension_count: int,
+    rope_freq_base: float,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    dim = rope_dimension_count
+    freqs = 1.0 / (
+        rope_freq_base ** ((torch.arange(0, dim) // 2).to(dtype=dtype) / dim * 2.0)
+    )
+    freqs = torch.outer(positions, freqs)
+    return freqs
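
A small standalone sketch of the new helper's output layout (the function is re-declared locally so the snippet runs on its own; dim=8 and base=10000.0 are arbitrary illustration values): each row holds one rotation angle per channel, and consecutive channel pairs share the same angle.

import torch

def compute_rotary_embedding_table(positions, rope_dimension_count, rope_freq_base, dtype):
    # Local copy of the helper added above, for illustration only.
    dim = rope_dimension_count
    freqs = 1.0 / (
        rope_freq_base ** ((torch.arange(0, dim) // 2).to(dtype=dtype) / dim * 2.0)
    )
    return torch.outer(positions, freqs)

positions = torch.arange(4, dtype=torch.float32)
table = compute_rotary_embedding_table(positions, 8, 10000.0, torch.float32)
print(table.shape)                                     # torch.Size([4, 8])
print(torch.allclose(table[:, 0::2], table[:, 1::2]))  # True: channel pairs share an angle
print(bool((table[0] == 0).all()))                     # True: position 0 gets zero rotation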
22 changes: 10 additions & 12 deletions sharktank/sharktank/models/flux/flux.py
@@ -19,6 +19,7 @@
 import torch.nn.functional as F
 
 from ...layers import *
+from ...layers.rotary_embedding import compute_rotary_embedding_table
 from ...types import *
 from ...utils.create_cache import *
 from ...utils.testing import make_rand_torch
@@ -318,16 +319,13 @@ def qk_norm(q, k, v, rms_q, rms_k):
 
 
 def rope(pos: AnyTensor, dim: int, theta: int) -> AnyTensor:
-    assert dim % 2 == 0
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    out = torch.stack(
-        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
-    )
-    # out = out.view(out.shape[0], out.shape[1], out.shape[2], out.shape[3], 2, 2)
-    out = out.view(out.shape[0], out.shape[1], out.shape[2], 2, 2)
-    return out.float()
+    out = compute_rotary_embedding_table(
+        positions=pos.flatten(),
+        rope_dimension_count=dim,
+        rope_freq_base=theta,
+        dtype=torch.float64,
+    ).float()
+    return out.view(list(pos.shape) + [out.shape[-1]])
 
 
 class MLPEmbedder(ThetaLayer):
@@ -353,10 +351,10 @@ def forward(self, ids: AnyTensor) -> AnyTensor:
         n_axes = ids.shape[-1]
         emb = torch.cat(
             [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
+            dim=2,
         )
 
-        return emb.unsqueeze(1)
+        return emb
 
 
 class LastLayer(ThetaLayer):
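
The reworked rope() returns a flat angle table of shape pos.shape + (dim,) rather than the (..., dim // 2, 2, 2) cos/sin matrices built by the removed code, which is also why EmbedND.forward now concatenates along dim=2 and returns emb without the extra unsqueeze(1). A standalone cross-check of the two formulations (dim, theta, and the position grid below are arbitrary illustration values):

import torch

dim, theta = 4, 10_000
pos = torch.arange(6, dtype=torch.float64).reshape(2, 3)

# Removed formulation: per-pair angles that fed the cos/sin 2x2 rotation matrices.
scale = torch.arange(0, dim, 2, dtype=torch.float64) / dim
omega = 1.0 / (theta**scale)
angles_old = torch.einsum("...n,d->...nd", pos, omega)                # (*pos.shape, dim // 2)

# New formulation: a flat angle table with each channel pair repeating the same angle.
freqs = 1.0 / (theta ** ((torch.arange(0, dim) // 2).double() / dim * 2.0))
angles_new = torch.outer(pos.flatten(), freqs).view(*pos.shape, dim)  # (*pos.shape, dim)

# Every even channel of the new table equals the corresponding per-pair angle of the old one.
print(torch.allclose(angles_new[..., 0::2], angles_old))              # True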
