add CogView3PlusTransformer2DModel
Cui-yshoho committed Nov 30, 2024
1 parent 18fc782 commit 943dea6
Showing 3 changed files with 500 additions and 0 deletions.
85 changes: 85 additions & 0 deletions mindone/diffusers/models/embeddings.py
@@ -450,6 +450,58 @@ def construct(self, text_embeds: ms.Tensor, image_embeds: ms.Tensor):
        return embeds


class CogView3PlusPatchEmbed(nn.Cell):
    def __init__(
        self,
        in_channels: int = 16,
        hidden_size: int = 2560,
        patch_size: int = 2,
        text_hidden_size: int = 4096,
        pos_embed_max_size: int = 128,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.text_hidden_size = text_hidden_size
        self.pos_embed_max_size = pos_embed_max_size
        # Linear projection for image patches
        self.proj = nn.Dense(in_channels * patch_size**2, hidden_size)

        # Linear projection for text embeddings
        self.text_proj = nn.Dense(text_hidden_size, hidden_size)

        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
        self.pos_embed = ms.Tensor.from_numpy(pos_embed).float()

    def construct(self, hidden_states: ms.Tensor, encoder_hidden_states: ms.Tensor) -> ms.Tensor:
        batch_size, channel, height, width = hidden_states.shape

        if height % self.patch_size != 0 or width % self.patch_size != 0:
            raise ValueError("Height and width must be divisible by patch size")

        height = height // self.patch_size
        width = width // self.patch_size
        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)

        # Project the patches
        hidden_states = self.proj(hidden_states)
        encoder_hidden_states = self.text_proj(encoder_hidden_states)
        hidden_states = ops.cat([encoder_hidden_states, hidden_states], axis=1)

        # Calculate text_length
        text_length = encoder_hidden_states.shape[1]

        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
        text_pos_embed = ops.zeros((text_length, self.hidden_size), dtype=image_pos_embed.dtype)
        pos_embed = ops.cat([text_pos_embed, image_pos_embed], axis=0)[None, ...]

        return (hidden_states + pos_embed).to(hidden_states.dtype)
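# Usage sketch (reviewer illustration, not part of this commit): a minimal
# smoke test of CogView3PlusPatchEmbed. The latent and text shapes below are
# assumptions chosen to match the defaults above, not values taken from the
# CogView3-Plus pipeline config.
from mindspore import ops

embed = CogView3PlusPatchEmbed(in_channels=16, hidden_size=2560, patch_size=2)
latents = ops.randn(1, 16, 64, 64)  # (B, C, H, W); H and W must be divisible by patch_size
text = ops.randn(1, 224, 4096)      # (B, text_seq_len, text_hidden_size)
tokens = embed(latents, text)       # text tokens first, then (64 // 2) * (64 // 2) image patches
print(tokens.shape)                 # (1, 224 + 1024, 2560)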


def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[ms.Tensor, Tuple[ms.Tensor, ms.Tensor]]:
@@ -1087,6 +1139,39 @@ def construct(self, timestep, guidance, pooled_projection):
        return conditioning


class CogView3CombinedTimestepSizeEmbeddings(nn.Cell):
    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
        super().__init__()

        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def construct(
        self,
        timestep: ms.Tensor,
        original_size: ms.Tensor,
        target_size: ms.Tensor,
        crop_coords: ms.Tensor,
        hidden_dtype: ms.Type,
    ) -> ms.Tensor:
        timesteps_proj = self.time_proj(timestep)

        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.shape[0], -1)
        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.shape[0], -1)
        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.shape[0], -1)

        # (B, 3 * 2 * condition_dim): each 2-component condition contributes 2 * condition_dim features
        condition_proj = ops.cat([original_size_proj, crop_coords_proj, target_size_proj], axis=1)

        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)

        conditioning = timesteps_emb + condition_emb
        return conditioning
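# Usage sketch (reviewer illustration, not part of this commit). Note that
# pooled_projection_dim must equal 3 * 2 * condition_dim, since each of the
# three (height, width)-style conditions is flattened to 2 * condition_dim
# features before concatenation. The concrete dims below are assumptions.
import mindspore as ms

emb_layer = CogView3CombinedTimestepSizeEmbeddings(
    embedding_dim=512, condition_dim=256, pooled_projection_dim=6 * 256
)
timestep = ms.Tensor([999], ms.int32)          # (B,)
original_size = ms.Tensor([[1024.0, 1024.0]])  # (B, 2) = (height, width)
target_size = ms.Tensor([[1024.0, 1024.0]])    # (B, 2)
crop_coords = ms.Tensor([[0.0, 0.0]])          # (B, 2) = (top, left)
cond = emb_layer(timestep, original_size, target_size, crop_coords, ms.float32)
print(cond.shape)                              # (1, 512)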


class HunyuanDiTAttentionPool(nn.Cell):
    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

45 changes: 45 additions & 0 deletions mindone/diffusers/models/normalization.py
@@ -375,6 +375,51 @@ def construct(
        return x


class CogView3PlusAdaLayerNormZeroTextImage(nn.Cell):
    r"""
    Adaptive layer norm zero (adaLN-Zero) that modulates the image and text streams jointly.

    Parameters:
        embedding_dim (`int`): The size of the conditioning embedding.
        dim (`int`): The feature dimension of the normalized inputs.
    """

    def __init__(self, embedding_dim: int, dim: int):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Dense(embedding_dim, 12 * dim, has_bias=True)
        self.norm_x = LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm_c = LayerNorm(dim, elementwise_affine=False, eps=1e-5)

    def construct(
        self,
        x: ms.Tensor,
        context: ms.Tensor,
        emb: Optional[ms.Tensor] = None,
    ) -> Tuple[ms.Tensor, ...]:
        emb = self.linear(self.silu(emb))
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            c_shift_msa,
            c_scale_msa,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = emb.chunk(12, axis=1)
        normed_x = self.norm_x(x)
        normed_context = self.norm_c(context)
        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
        context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
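# Usage sketch (reviewer illustration, not part of this commit): how a
# transformer block might consume the twelve modulation tensors. Dimensions
# are assumptions, and `attn_output` / `context_attn_output` stand in for the
# block's joint-attention results.
from mindspore import ops

norm = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=512, dim=2560)
x = ops.randn(1, 1024, 2560)       # image tokens (B, N_img, dim)
context = ops.randn(1, 224, 2560)  # text tokens (B, N_txt, dim)
emb = ops.randn(1, 512)            # combined timestep/size conditioning (B, embedding_dim)
(x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp,
 ctx_mod, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp) = norm(x, context, emb)
# After attention, the gates would rescale each residual branch, e.g.:
#     x = x + gate_msa.unsqueeze(1) * attn_output
#     context = context + c_gate_msa.unsqueeze(1) * context_attn_output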


class CogVideoXLayerNormZero(nn.Cell):
    def __init__(
        self,
