diff --git a/tests/diffusion_labs/test_ldm.py b/tests/diffusion_labs/test_ldm.py new file mode 100644 index 00000000..660d6032 --- /dev/null +++ b/tests/diffusion_labs/test_ldm.py @@ -0,0 +1,412 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy +from functools import partial +from typing import Dict, List, Optional, Sequence, Set + +import pytest +import torch +from tests.test_utils import assert_expected, set_rng_seed +from torch import nn, Tensor +from torchmultimodal.diffusion_labs.models.ldm.ldm import LDMModel, LDMUNet + + +@pytest.fixture(autouse=True) +def set_seed(): + set_rng_seed(54321) + + +@pytest.fixture +def in_channels() -> int: + return 5 + + +@pytest.fixture +def out_channels() -> int: + return 6 + + +@pytest.fixture +def model_channels() -> int: + return 32 + + +@pytest.fixture +def channel_multipliers() -> List[int]: + return [2, 4] + + +@pytest.fixture +def attention_resolutions() -> Set[int]: + return {2} + + +@pytest.fixture +def context_dim() -> int: + return 32 + + +@pytest.fixture +def context_dims(context_dim) -> List[int]: + return [context_dim] + + +@pytest.fixture +def num_res_blocks() -> int: + return 2 + + +@pytest.fixture +def num_res_blocks_list() -> List[int]: + return [2, 2] + + +@pytest.fixture +def num_attn_heads() -> int: + return 4 + + +@pytest.fixture +def num_channels_per_head() -> int: + return 32 + + +@pytest.fixture +def image_size() -> int: + return 8 + + +@pytest.fixture +def bsize() -> int: + return 5 + + +@pytest.fixture +def context_seqlen() -> int: + return 5 + + +@pytest.fixture +def max_time() -> int: + return 20 + + +@pytest.fixture +def coordinate_embedding_dim() -> int: + return 10 + + +@pytest.fixture +def embed_input_size() -> bool: + return True + + +@pytest.fixture +def embed_target_size() -> bool: + return True + + +@pytest.fixture +def embed_crop_params() -> bool: + return True + + +@pytest.fixture +def num_coordinates(embed_input_size, embed_target_size, embed_crop_params) -> int: + return 2 * (embed_input_size + embed_target_size + embed_crop_params) + + +@pytest.fixture +def pooled_text_embedding_dim() -> int: + return 20 + + +@pytest.fixture +def x(bsize, in_channels, image_size) -> Tensor: + return torch.randn(bsize, in_channels, image_size, image_size) + + +@pytest.fixture +def context(bsize, context_seqlen, context_dim) -> List[Tensor]: + return [torch.randn(bsize, context_seqlen, context_dim)] + + +@pytest.fixture +def additional_embeddings( + bsize, pooled_text_embedding_dim, num_coordinates +) -> Dict[str, Tensor]: + return { + "pooled_text_embed": torch.randn(bsize, pooled_text_embedding_dim), + "coordinates": torch.randn(bsize, num_coordinates), + } + + +@pytest.fixture +def time(bsize, max_time) -> Tensor: + return torch.randint(1, max_time, (bsize,)) + + +@pytest.fixture +def context_keys() -> List[str]: + return ["a", "b", "c", "d", "e"] + + +@pytest.fixture +def context_dict(bsize, context_seqlen, context_dim, context_keys) -> Dict[str, Tensor]: + return {k: torch.randn(bsize, context_seqlen, context_dim) for k in context_keys} + + +# All expected values come after first testing that LDMUNet has the exact +# output as the corresponding class in d2go, then simply forward passing +# LDMUNet with params, random seed, and initialization order in this file. 
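+
+# As a hedged, illustrative sketch only: the helper below shows roughly how such a
+# value could be recomputed from the fixture values used in this file. It is not
+# collected by pytest, and the tests additionally re-initialize zeroed output
+# projections via `_unzero_unet_params`, so this raw forward pass is not expected
+# to reproduce the asserted constants exactly.
+def _example_unet_forward_mean() -> Tensor:
+    set_rng_seed(54321)
+    unet = LDMUNet(
+        in_channels=5,
+        model_channels=32,
+        out_channels=6,
+        num_res_blocks_per_level=2,
+        attention_resolutions={2},
+        channel_multipliers=[2, 4],
+        context_dims=[32],
+        num_attention_heads=4,
+    )
+    x = torch.randn(5, 5, 8, 8)
+    timestep = torch.randint(1, 20, (5,))
+    context = [torch.randn(5, 5, 32)]
+    return unet(x, timestep, context).mean()
+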
+class TestLDMUNet: + @pytest.fixture + def unet( + self, + in_channels, + out_channels, + model_channels, + channel_multipliers, + attention_resolutions, + context_dims, + num_res_blocks, + num_attn_heads, + ): + return partial( + LDMUNet, + in_channels=in_channels, + out_channels=out_channels, + model_channels=model_channels, + context_dims=context_dims, + attention_resolutions=attention_resolutions, + channel_multipliers=channel_multipliers, + num_res_blocks_per_level=num_res_blocks, + num_attention_heads=num_attn_heads, + ) + + def _unzero_unet_params(self, unet: LDMUNet): + for p in unet.parameters(): + if torch.allclose(p, torch.zeros_like(p)): + nn.init.normal_(p) + + def test_forward(self, unet, x, time, context, out_channels): + unet_module = unet() + self._unzero_unet_params(unet_module) + expected_shape = torch.Size( + [x.size()[0], out_channels, x.size()[2], x.size()[3]] + ) + expected = torch.tensor(-4.60389) + actual = unet_module(x, time, context) + assert_expected(actual.size(), expected_shape) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-4) + + def test_forward_no_context(self, unet, x, time, context): + unet_module = unet(context_dims=None) + self._unzero_unet_params(unet_module) + expected = torch.tensor(-1.93169) + actual = unet_module(x, time) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-3) + actual = unet_module(x, time, context) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-3) + + def test_forward_additional_embeddings( + self, + unet, + x, + time, + context, + out_channels, + coordinate_embedding_dim, + embed_input_size, + embed_target_size, + embed_crop_params, + pooled_text_embedding_dim, + additional_embeddings, + ): + unet_module = unet( + coordinate_embedding_dim=coordinate_embedding_dim, + embed_input_size=embed_input_size, + embed_target_size=embed_target_size, + embed_crop_params=embed_crop_params, + pooled_text_embedding_dim=pooled_text_embedding_dim, + ) + self._unzero_unet_params(unet_module) + expected_shape = torch.Size( + [x.size()[0], out_channels, x.size()[2], x.size()[3]] + ) + expected = torch.tensor(2.99702) + actual = unet_module( + x, time, context, additional_embeddings=additional_embeddings + ) + assert_expected(actual.size(), expected_shape) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-4) + + def test_forward_num_res_blocks_list( + self, unet, num_res_blocks_list, x, time, context + ): + unet_module = unet(num_res_blocks_per_level=num_res_blocks_list) + self._unzero_unet_params(unet_module) + expected = torch.tensor(-4.60389) + actual = unet_module(x, time, context) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-3) + + def test_forward_res_updown(self, unet, x, time, context): + unet_module = unet(use_res_block_updown=True) + self._unzero_unet_params(unet_module) + expected = torch.tensor(-5.13025) + actual = unet_module(x, time, context) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-4) + + def test_forward_linear_projection(self, unet, x, time, context): + unet_module = unet(use_linear_projection_in_transformer=True) + self._unzero_unet_params(unet_module) + expected = torch.tensor(-4.60389) + actual = unet_module(x, time, context) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-3) + + def test_forward_scale_shift_conditional(self, unet, x, time, context): + unet_module = unet(scale_shift_conditional=True) + self._unzero_unet_params(unet_module) + expected = torch.tensor(-0.12192) + actual = unet_module(x, time, context) + 
assert_expected(actual.mean(), expected, rtol=0, atol=1e-4) + + def test_forward_heterogeneous_depths(self, unet, x, time, context): + unet_module = unet(num_transformer_layers=[1, 2]) + self._unzero_unet_params(unet_module) + expected = torch.tensor(3.86524) + actual = unet_module(x, time, context) + assert_expected(actual.mean(), expected, rtol=0, atol=1e-4) + + def test_unet_num_res_blocks_channels_mismatch_error(self, unet): + with pytest.raises(ValueError): + _ = unet(num_res_blocks_per_level=[1, 2, 3]) + + def test_unet_norm_group_error(self, unet): + with pytest.raises(ValueError): + _ = unet(model_channels=17) + + def test_unet_context_dims_transformer_layers_mismatch_error( + self, unet, context_dim + ): + with pytest.raises(ValueError): + _ = unet(context_dims=[context_dim] * 2) + + with pytest.raises(ValueError): + _ = unet(num_transformer_layers=[1, 1, 1]) + + def test_unet_context_list_len_error(self, unet, x, time, context): + unet_module = unet() + with pytest.raises(RuntimeError): + unet_module(x, time, context + deepcopy(context)) + + def test_unet_context_dim_mismatch(self, unet, x, time, context): + unet_module = unet() + with pytest.raises(RuntimeError): + unet_module(x, time, torch.cat([context[0], context[0]], dim=-1)) + + def test_unet_num_heads_channels_errors( + self, + unet, + num_attn_heads, + num_channels_per_head, + ): + # Test when both num_attn_heads and num_channels_per_head are None + with pytest.raises(ValueError): + _ = unet(num_attention_heads=None, num_channels_per_attention_head=None) + + with pytest.raises(ValueError): + _ = unet( + num_attention_heads=num_attn_heads, + num_channels_per_attention_head=num_channels_per_head, + ) + + +class TestLDMModel: + @pytest.fixture + def unet(self): + class SimpleUNet(nn.Module): + def forward( + self, + x: Tensor, + t: Tensor, + context_list: Optional[Sequence[Tensor]] = None, + additional_embeddings: Optional[Dict[str, Tensor]] = None, + ): + if additional_embeddings is not None: + return torch.stack( + [t.mean() for t in additional_embeddings.values()], dim=0 + ).sum() + + if isinstance(context_list, Sequence) and len(context_list) > 0: + return torch.cat(context_list, dim=1) + else: + return torch.zeros(1) + + return SimpleUNet() + + @pytest.fixture + def model(self, unet): + return partial(LDMModel, unet=unet) + + @pytest.mark.parametrize( + "cond_keys,expected_value", + [([], 0.0), (["a", "b"], 26.7966), (["a", "b", "c"], 48.4302)], + ) + def test_forward(self, model, x, time, context_dict, cond_keys, expected_value): + ldm_model = model(cond_keys=cond_keys) + expected = torch.tensor(expected_value) + actual = ldm_model(x, time, context_dict) + assert_expected(actual.prediction.sum(), expected, rtol=0, atol=1e-4) + + @pytest.mark.parametrize( + "cond_keys,expected_value", + [(["a"], 45.04681), (["b"], -18.25021)], + ) + def test_forward_single_context( + self, model, x, time, context_dict, cond_keys, expected_value + ): + ldm_model = model(cond_keys=cond_keys) + expected = torch.tensor(expected_value) + actual = ldm_model(x, time, context_dict) + assert_expected(actual.prediction.sum(), expected, rtol=0, atol=1e-4) + + @pytest.mark.parametrize( + "additional_cond_keys,expected_value", + [ + (["pooled_text_embed"], -0.09103), + (["pooled_text_embed", "coordinates"], -0.1936), + ], + ) + def test_forward_additional_embeddings( + self, + model, + x, + time, + context_dict, + additional_cond_keys, + additional_embeddings, + expected_value, + ): + ldm_model = model(additional_cond_keys=additional_cond_keys) + 
expected = torch.tensor(expected_value) + context_dict.update(additional_embeddings) + actual = ldm_model(x, time, context_dict) + assert_expected(actual.prediction.sum(), expected, rtol=0, atol=1e-4) + + def test_forward_context_dim_error( + self, model, x, time, bsize, context_seqlen, context_dim + ): + context_dict = { + "a": torch.randn(bsize, context_seqlen, context_dim), + "b": torch.randn(bsize, context_dim), + } + with pytest.raises(RuntimeError): + model(cond_keys=["a", "b"])(x, time, context_dict) + + # Should not raise runtime error because 'b' is not a cond key + model(cond_keys=["a"])(x, time, context_dict) diff --git a/tests/diffusion_labs/test_ldm_spatial_transformer.py b/tests/diffusion_labs/test_ldm_spatial_transformer.py new file mode 100644 index 00000000..55dcc85e --- /dev/null +++ b/tests/diffusion_labs/test_ldm_spatial_transformer.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy +from functools import partial + +import pytest +import torch +from tests.test_utils import assert_expected, set_rng_seed +from torch import nn +from torchmultimodal.diffusion_labs.models.ldm.spatial_transformer import ( + SpatialTransformer, + SpatialTransformerCrossAttentionLayer, +) + + +@pytest.fixture(autouse=True) +def set_seed(): + set_rng_seed(54321) + + +@pytest.fixture +def in_channels(): + return 16 + + +@pytest.fixture +def num_heads(): + return 2 + + +@pytest.fixture +def num_layers(): + return 3 + + +@pytest.fixture +def context_dim(): + return 8 + + +@pytest.fixture +def batch_size(): + return 3 + + +@pytest.fixture +def x(batch_size, in_channels): + return torch.randn(batch_size, 10, in_channels) + + +@pytest.fixture +def x_img(batch_size, in_channels): + return torch.randn(batch_size, in_channels, 8, 8) + + +@pytest.fixture +def context(batch_size, context_dim): + return torch.randn(batch_size, 6, context_dim) + + +# All expected values come after first testing that SpatialTransformerCrossAttentionLayer +# has the exact output as the corresponding class in d2go, then simply +# forward passing SpatialTransformerCrossAttentionLayer with params, random seed, and +# initialization order in this file. 
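+
+# Purely illustrative sketch (not collected by pytest): with the fixture values
+# defined above, the cross-attention layer is exercised along these lines, and the
+# output keeps the query shape [batch, seq_len, d_model].
+def _example_cross_attention_forward() -> torch.Tensor:
+    set_rng_seed(54321)
+    attn = SpatialTransformerCrossAttentionLayer(d_model=16, num_heads=2, context_dim=8)
+    x = torch.randn(3, 10, 16)
+    context = torch.randn(3, 6, 8)
+    return attn(x, context)
+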
+class TestSpatialTransformerCrossAttentionLayer: + @pytest.fixture + def attn(self, in_channels, num_heads): + return partial( + SpatialTransformerCrossAttentionLayer, + d_model=in_channels, + num_heads=num_heads, + ) + + def test_cross_attn_forward_with_context(self, attn, x, context_dim, context): + attn_module = attn(context_dim=context_dim) + actual = attn_module(x, context) + expected = torch.tensor(46.95579) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + def test_cross_attn_forward_without_context(self, attn, in_channels, x): + attn_module = attn(context_dim=in_channels) + actual = attn_module(x) + expected = torch.tensor(5.83984) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + def test_self_attn_forward(self, attn, x, context): + attn_module = attn() + actual = attn_module(x) + expected = torch.tensor(-1.7353) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + actual = attn_module(x, context) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + +# All expected values come after first testing that SpatialTransformer +# has the exact output as the corresponding class in d2go, then simply +# forward passing SpatialTransformer with params, random seed, and +# initialization order in this file. +class TestSpatialTransformer: + @pytest.fixture + def transformer(self, in_channels, num_heads, num_layers): + return partial( + SpatialTransformer, + in_channels=in_channels, + num_heads=num_heads, + num_layers=num_layers, + norm_groups=2, + ) + + def _unzero_output_proj(self, transformer): + """ + Output proj is initialized with zero weights due to + fixup initialization. Change to non-zero proj weights to + run unit tests with different input combinations. + """ + for p in transformer.out_projection.parameters(): + nn.init.normal_(p) + return transformer + + def test_transformer_forward_with_context( + self, transformer, x_img, context_dim, num_layers, context + ): + transformer_module = self._unzero_output_proj( + transformer(context_dims=[context_dim] * num_layers) + ) + context_list = [deepcopy(context) for _ in range(num_layers)] + actual = transformer_module(x_img, context_list) + expected = torch.tensor(2401.9578) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-3) + + def test_transformer_forward_without_context( + self, transformer, x_img, context, num_layers + ): + transformer_module = self._unzero_output_proj(transformer()) + expected = torch.tensor(-1634.7414) + actual = transformer_module(x_img) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-3) + context_list = [deepcopy(context) for _ in range(num_layers)] + actual = transformer_module(x_img, context_list) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-3) + + def test_transformer_forward_with_auto_repeated_context( + self, transformer, x_img, context_dim, num_layers, context + ): + transformer_module = self._unzero_output_proj( + transformer(context_dims=[context_dim]) + ) + context_list = [deepcopy(context) for _ in range(num_layers)] + actual = transformer_module(x_img, context_list) + expected = torch.tensor(2401.9578) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-3) + + def test_context_dims_layers_mismatch(self, transformer, context_dim, num_layers): + with pytest.raises(ValueError): + transformer(context_dims=[context_dim] * (num_layers - 1)) + + def test_forward_context_dims_layers_mismatch( + self, transformer, context, context_dim, num_layers + ): + transformer_module = transformer(context_dims=[context_dim] * 
num_layers) + context_list = [deepcopy(context) for _ in range(num_layers - 1)] + with pytest.raises(RuntimeError): + transformer_module(x_img, context_list) + + def test_transformer_forward_with_linear_proj( + self, transformer, x_img, context_dim, num_layers, context + ): + transformer_module = self._unzero_output_proj( + transformer( + context_dims=[context_dim] * num_layers, use_linear_projections=True + ) + ) + context_list = [deepcopy(context) for _ in range(num_layers)] + actual = transformer_module(x_img, context_list) + expected = torch.tensor(2401.9578) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-3) diff --git a/tests/diffusion_labs/test_vae.py b/tests/diffusion_labs/test_vae.py new file mode 100644 index 00000000..97cca40b --- /dev/null +++ b/tests/diffusion_labs/test_vae.py @@ -0,0 +1,167 @@ +#!/usr/bin/env fbpython +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +import torch.distributions as tdist +from tests.test_utils import assert_expected, set_rng_seed +from torchmultimodal.diffusion_labs.models.vae.vae import variational_autoencoder + + +@pytest.fixture(autouse=True) +def set_seed(): + set_rng_seed(98765) + + +@pytest.fixture +def embedding_channels(): + return 6 + + +@pytest.fixture +def in_channels(): + return 2 + + +@pytest.fixture +def out_channels(): + return 5 + + +@pytest.fixture +def z_channels(): + return 3 + + +@pytest.fixture +def channels(): + return 4 + + +@pytest.fixture +def num_res_blocks(): + return 2 + + +@pytest.fixture +def channel_multipliers(): + return (1, 2, 4) + + +@pytest.fixture +def norm_groups(): + return 2 + + +@pytest.fixture +def norm_eps(): + return 1e-05 + + +@pytest.fixture +def x(in_channels): + bsize = 2 + height = 16 + width = 16 + return torch.randn(bsize, in_channels, height, width) + + +@pytest.fixture +def z(embedding_channels): + bsize = 2 + height = 4 + width = 4 + return torch.randn(bsize, embedding_channels, height, width) + + +class TestVariationalAutoencoder: + @pytest.fixture + def vae( + self, + in_channels, + out_channels, + embedding_channels, + z_channels, + channels, + norm_groups, + norm_eps, + channel_multipliers, + num_res_blocks, + ): + return variational_autoencoder( + embedding_channels=embedding_channels, + in_channels=in_channels, + out_channels=out_channels, + z_channels=z_channels, + channels=channels, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + norm_groups=norm_groups, + norm_eps=norm_eps, + ) + + def test_encode(self, vae, x, embedding_channels, channel_multipliers): + posterior = vae.encode(x) + expected_shape = torch.Size( + [ + x.size(0), + embedding_channels, + x.size(2) // 2 ** (len(channel_multipliers) - 1), + x.size(3) // 2 ** (len(channel_multipliers) - 1), + ] + ) + expected_mean = torch.tensor(-3.4872) + assert_expected(posterior.mean.size(), expected_shape) + assert_expected(posterior.mean.sum(), expected_mean, rtol=0, atol=1e-4) + + expected_stddev = torch.tensor(193.3726) + assert_expected(posterior.stddev.size(), expected_shape) + assert_expected(posterior.stddev.sum(), expected_stddev, rtol=0, atol=1e-4) + + # compute kl with standard gaussian + expected_kl = torch.tensor(9.8025) + torch_kl_divergence = tdist.kl_divergence( + posterior, + tdist.Normal( + torch.zeros_like(posterior.mean), torch.ones_like(posterior.stddev) + ), + ).sum() + 
assert_expected(torch_kl_divergence, expected_kl, rtol=0, atol=1e-4) + + # compare sample shape + assert_expected(posterior.rsample().size(), expected_shape) + + def test_decode(self, vae, z, out_channels, channel_multipliers): + actual = vae.decode(z) + expected = torch.tensor(-156.1534) + expected_shape = torch.Size( + [ + z.size(0), + out_channels, + z.size(2) * 2 ** (len(channel_multipliers) - 1), + z.size(3) * 2 ** (len(channel_multipliers) - 1), + ] + ) + assert_expected(actual.size(), expected_shape) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + @pytest.mark.parametrize( + "sample_posterior,expected_value", [(True, -153.6517), (False, -178.8593)] + ) + def test_forward(self, vae, x, out_channels, sample_posterior, expected_value): + actual = vae(x, sample_posterior=sample_posterior).decoder_output + expected = torch.tensor(expected_value) + expected_shape = torch.Size( + [ + x.size(0), + out_channels, + x.size(2), + x.size(3), + ] + ) + assert_expected(actual.size(), expected_shape) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) diff --git a/tests/diffusion_labs/test_vae_attention.py b/tests/diffusion_labs/test_vae_attention.py new file mode 100644 index 00000000..0fabe96f --- /dev/null +++ b/tests/diffusion_labs/test_vae_attention.py @@ -0,0 +1,85 @@ +#!/usr/bin/env fbpython +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, set_rng_seed +from torch import nn +from torchmultimodal.diffusion_labs.models.vae.attention import ( + attention_res_block, + AttentionResBlock, + VanillaAttention, +) + + +@pytest.fixture(autouse=True) +def set_seed(): + set_rng_seed(1234) + + +@pytest.fixture +def channels(): + return 64 + + +@pytest.fixture +def norm_groups(): + return 16 + + +@pytest.fixture +def norm_eps(): + return 1e-05 + + +@pytest.fixture +def x(channels): + bsize = 2 + height = 16 + width = 16 + return torch.randn(bsize, channels, height, width) + + +class TestVanillaAttention: + @pytest.fixture + def attn(self, channels): + return VanillaAttention(channels) + + def test_forward(self, x, attn): + actual = attn(x) + expected = torch.tensor(32.0883) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + assert_expected(actual.shape, x.shape) + + +class TestAttentionResBlock: + @pytest.fixture + def attn(self, channels, norm_groups, norm_eps): + return AttentionResBlock( + channels, + attn_module=nn.Identity(), + norm_groups=norm_groups, + norm_eps=norm_eps, + ) + + def test_forward(self, x, attn): + actual = attn(x) + expected = torch.tensor(295.1067) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + assert_expected(actual.shape, x.shape) + + def test_channel_indivisible_norm_group_error(self): + with pytest.raises(ValueError): + _ = AttentionResBlock(64, nn.Identity(), norm_groups=30) + + +def test_attention_res_block(channels, x): + attn = attention_res_block(channels) + expected = torch.tensor(69.692) + actual = attn(x) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + assert_expected(actual.shape, x.shape) diff --git a/tests/diffusion_labs/test_vae_encoder_decoder.py b/tests/diffusion_labs/test_vae_encoder_decoder.py new file mode 100644 index 00000000..80d918cd --- /dev/null +++ b/tests/diffusion_labs/test_vae_encoder_decoder.py @@ -0,0 +1,267 @@ +#!/usr/bin/env fbpython +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +import pytest +import torch +from tests.test_utils import assert_expected, set_rng_seed + +from torchmultimodal.diffusion_labs.models.vae.encoder_decoder import ( + res_block, + res_block_stack, + ResNetDecoder, + ResNetEncoder, +) + + +@pytest.fixture(autouse=True) +def set_seed(): + set_rng_seed(54321) + + +@pytest.fixture +def in_channels(): + return 2 + + +@pytest.fixture +def out_channels(): + return 5 + + +@pytest.fixture +def z_channels(): + return 3 + + +@pytest.fixture +def channels(): + return 4 + + +@pytest.fixture +def num_res_blocks(): + return 2 + + +@pytest.fixture +def channel_multipliers(): + return (1, 2) + + +@pytest.fixture +def norm_groups(): + return 2 + + +@pytest.fixture +def norm_eps(): + return 1e-05 + + +@pytest.fixture +def x(in_channels): + bsize = 2 + height = 16 + width = 16 + return torch.randn(bsize, in_channels, height, width) + + +@pytest.fixture +def z(z_channels): + bsize = 2 + height = 4 + width = 4 + return torch.randn(bsize, z_channels, height, width) + + +class TestResNetEncoder: + @pytest.fixture + def encoder( + self, + in_channels, + z_channels, + channels, + num_res_blocks, + channel_multipliers, + norm_groups, + norm_eps, + ): + return partial( + ResNetEncoder, + in_channels=in_channels, + z_channels=z_channels, + channels=channels, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + norm_groups=norm_groups, + norm_eps=norm_eps, + ) + + @pytest.mark.parametrize("double_z", [True, False]) + def test_forward_dims(self, encoder, x, z_channels, channel_multipliers, double_z): + encoder_module = encoder(double_z=double_z) + output = encoder_module(x) + assert_expected( + output.size(), + torch.Size( + [ + x.size(0), + z_channels * (2 if double_z else 1), + x.size(2) // 2 ** (len(channel_multipliers) - 1), + x.size(3) // 2 ** (len(channel_multipliers) - 1), + ] + ), + ) + + def test_forward(self, encoder, x): + encoder_module = encoder() + actual = encoder_module(x) + expected = torch.tensor(126.5277) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + def test_channel_indivisble_norm_group_error(self, encoder): + with pytest.raises(ValueError): + _ = encoder(norm_groups=7) + + +class TestResNetDecoder: + @pytest.fixture + def decoder( + self, + out_channels, + z_channels, + channels, + num_res_blocks, + channel_multipliers, + norm_groups, + norm_eps, + ): + return partial( + ResNetDecoder, + out_channels=out_channels, + z_channels=z_channels, + channels=channels, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + norm_groups=norm_groups, + norm_eps=norm_eps, + ) + + @pytest.mark.parametrize("output_alpha_channel", [True, False]) + def test_forward_dims( + self, decoder, z, out_channels, channel_multipliers, output_alpha_channel + ): + decoder_module = decoder(output_alpha_channel=output_alpha_channel) + output = decoder_module(z) + assert_expected( + output.size(), + torch.Size( + [ + z.size(0), + out_channels + (1 if output_alpha_channel else 0), + z.size(2) * 2 ** (len(channel_multipliers) - 1), + z.size(3) * 2 ** (len(channel_multipliers) - 1), + ] + ), + ) + + def test_forward(self, decoder, z): + decoder_module = decoder() + actual = decoder_module(z) + expected = torch.tensor(-10.0260) + assert_expected(actual.sum(), expected, rtol=0, 
atol=1e-4) + + def test_forward_alpha_channel(self, decoder, z): + decoder_module = decoder(output_alpha_channel=True) + actual = decoder_module(z) + expected = torch.tensor(-16.2157) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + def test_channel_indivisble_norm_group_error(self, decoder): + with pytest.raises(ValueError): + _ = decoder(norm_groups=7) + + +@pytest.mark.parametrize("out_channels,expected_value", [(2, 52.2716), (4, 152.8285)]) +def test_res_block(x, out_channels, expected_value): + in_channels = x.size(1) + res_block_module = res_block(in_channels, out_channels, dropout=0.3, norm_groups=1) + actual = res_block_module(x) + expected = torch.tensor(expected_value) + assert_expected( + actual.size(), + torch.Size( + [ + x.size(0), + out_channels, + x.size(2), + x.size(3), + ] + ), + ) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + +@pytest.mark.parametrize( + "needs_upsample,needs_downsample,expected_value", + [(False, True, 28.02428), (False, False, 382.8569), (True, False, 581.62414)], +) +def test_res_block_stack( + x, + in_channels, + channels, + num_res_blocks, + needs_upsample, + needs_downsample, + expected_value, +): + res_block_stack_module = res_block_stack( + in_channels=in_channels, + out_channels=channels, + num_blocks=num_res_blocks, + dropout=0.1, + needs_upsample=needs_upsample, + needs_downsample=needs_downsample, + norm_groups=1, + ) + actual = res_block_stack_module(x) + expected = torch.tensor(expected_value) + if needs_upsample: + size_multipler = 2 + elif needs_downsample: + size_multipler = 0.5 + else: + size_multipler = 1 + assert_expected( + actual.size(), + torch.Size( + [ + x.size(0), + channels, + int(x.size(2) * size_multipler), + int(x.size(3) * size_multipler), + ] + ), + ) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + +def test_res_block_stack_exception( + in_channels, + channels, + num_res_blocks, +): + with pytest.raises(ValueError): + _ = res_block_stack( + in_channels=in_channels, + out_channels=channels, + num_blocks=num_res_blocks, + needs_upsample=True, + needs_downsample=True, + ) diff --git a/tests/diffusion_labs/test_vae_residual_sampling.py b/tests/diffusion_labs/test_vae_residual_sampling.py new file mode 100644 index 00000000..44ac41e7 --- /dev/null +++ b/tests/diffusion_labs/test_vae_residual_sampling.py @@ -0,0 +1,78 @@ +#!/usr/bin/env fbpython +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from functools import partial + +import pytest +import torch +from tests.test_utils import assert_expected, set_rng_seed +from torchmultimodal.diffusion_labs.models.vae.residual_sampling import ( + Downsample2D, + Upsample2D, +) + + +@pytest.fixture(autouse=True) +def set_seed(): + set_rng_seed(54321) + + +@pytest.fixture +def in_channels(): + return 2 + + +@pytest.fixture +def x(in_channels): + bsize = 2 + height = 16 + width = 16 + return torch.randn(bsize, in_channels, height, width) + + +def test_upsample(in_channels, x): + upsampler = Upsample2D(channels=in_channels) + actual = upsampler(x) + expected = torch.tensor(-350.5232) + assert_expected( + actual.size(), + torch.Size( + [ + x.size(0), + x.size(1), + x.size(2) * 2, + x.size(3) * 2, + ] + ), + ) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) + + +class TestDownsample: + @pytest.fixture + def downsampler_fn(self, in_channels): + return partial(Downsample2D, channels=in_channels) + + @pytest.mark.parametrize( + "asymmetric_padding,expected_value", [(True, -18.3393), (False, -28.8385)] + ) + def test_downsample(self, downsampler_fn, x, asymmetric_padding, expected_value): + downsampler = downsampler_fn(asymmetric_padding=asymmetric_padding) + actual = downsampler(x) + expected = torch.tensor(expected_value) + assert_expected( + actual.size(), + torch.Size( + [ + x.size(0), + x.size(1), + x.size(2) // 2, + x.size(3) // 2, + ] + ), + ) + assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) diff --git a/torchmultimodal/diffusion_labs/models/ldm/__init__.py b/torchmultimodal/diffusion_labs/models/ldm/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/ldm/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/diffusion_labs/models/ldm/ldm.py b/torchmultimodal/diffusion_labs/models/ldm/ldm.py new file mode 100644 index 00000000..f0748dac --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/ldm/ldm.py @@ -0,0 +1,712 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADMStack +from torchmultimodal.diffusion_labs.models.ldm.spatial_transformer import ( + SpatialTransformer, +) +from torchmultimodal.diffusion_labs.models.vae.res_block import adm_cond_proj, ResBlock +from torchmultimodal.diffusion_labs.models.vae.residual_sampling import ( + Downsample2D, + Upsample2D, +) +from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput +from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm +from torchmultimodal.modules.layers.position_embedding import ( + SinusoidalPositionEmbeddings, +) +from torchmultimodal.utils.common import init_module_parameters_to_zero + + +class LDMUNet(nn.Module): + """Implements the UNet used by Latent Diffusion Models (LDMs). Composes all + the blocks for the downsampling encoder, bottleneck, and upsampling encoder in + the LDMUNet. 
Constructs the network by adding residual blocks, spatial transformer
+    blocks, and up/downsampling blocks for every layer based on user specifications.
+
+    Follows the architecture described in "High-Resolution Image Synthesis with Latent
+    Diffusion Models" (https://arxiv.org/abs/2112.10752).
+
+    Code ref:
+    https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/diffusionmodules/openaimodel.py#L413
+
+    Overall structure:
+        time -> time_embedding -> t
+        x, t, context -> encoder -> bottleneck -> decoder -> out
+
+    Attributes:
+        in_channels (int): number of input channels.
+        model_channels (int): base channel count for the model.
+        out_channels (int): number of output channels.
+        num_res_blocks_per_level (Union[int, Sequence[int]]): number of residual
+            blocks per level. If an integer, the same number of blocks is used at
+            every level. If a sequence of integers, the sequence length must
+            match the length of `channel_multipliers`.
+        attention_resolutions (Sequence[int]): sequence of downsampling rates at
+            which attention will be performed. For example, if this contains 2,
+            then at 2x downsampling, attention will be added in both the down
+            and up blocks.
+        channel_multipliers (Sequence[int]): list of channel multipliers used by the encoder.
+            The decoder uses them in reverse order. Defaults to [1, 2, 4, 8].
+        context_dims (Sequence[int], optional): list of dimensions of the conditional
+            context tensors. This enables sequential attention support by adding new conditioning
+            models to the end of the context list. If the length of `context_dims` differs from
+            `num_transformer_layers`, each element of `context_dims` is used
+            `int(num_transformer_layers / len(context_dims))` times. Defaults to None.
+        use_res_block_updown (bool): if True, use residual blocks for up/downsampling.
+            Defaults to False.
+        scale_shift_conditional (bool): if True, splits the conditional embedding into two separate
+            projections and adds it to the hidden state as Norm(h)(w + 1) + b in residual blocks,
+            as described in Appendix A of "Improved Denoising Diffusion Probabilistic Models"
+            (https://arxiv.org/abs/2102.09672).
+            Defaults to False.
+        num_attention_heads (int, optional): Number of attention heads used in the spatial
+            transformer. If None, then `num_channels_per_attention_head` must be provided.
+            Defaults to None.
+        num_channels_per_attention_head (int, optional): Number of channels for each attention
+            head. If None, then `num_attention_heads` must be provided. Defaults to None.
+        num_transformer_layers (Union[int, Sequence[int]]): Number of layers in the spatial transformer.
+            If an integer, the same number of layers is used for
+            each block. If a sequence of integers, the sequence length must
+            match the length of `channel_multipliers`. Defaults to 1.
+        use_linear_projection_in_transformer (bool): If True, use linear input and output
+            projections in the spatial transformer instead of 1x1 convolutions. Defaults to False.
+        dropout (float): Dropout value passed to residual blocks. Defaults to 0.0.
+        embed_input_size (bool): if True, embed the input image size and add it to the timestep
+            embedding. coordinate_embedding_dim must be positive.
+        embed_target_size (bool): if True, embed the target image size and add it to the timestep
+            embedding. coordinate_embedding_dim must be positive.
+        embed_crop_params (bool): if True, embed the crop parameters and add them to the timestep
+            embedding.
 coordinate_embedding_dim must be positive.
+        coordinate_embedding_dim (int): embedding dimension for the sinusoidal embedding
+            of coordinates (e.g. input and target size). If > 0, embed_input_size,
+            embed_target_size, or embed_crop_params must be True.
+        pooled_text_embedding_dim (int): dimension of the pooled text embedding
+            to be added to the timestep embedding.
+
+    Args:
+        x (Tensor): input Tensor of shape [b, in_channels, h, w]
+        timestep (Tensor): diffusion timesteps of shape [b, ]
+        context_list (Sequence[Tensor], optional): Optional list of context Tensors,
+            each of shape [b, seq_len_i, context_dim_i]. Defaults to None.
+
+    Raises:
+        ValueError: If `num_res_blocks_per_level` and `channel_multipliers` do not
+            have the same length.
+        ValueError: If `num_transformer_layers` and `channel_multipliers` do not
+            have the same length.
+        ValueError: If both `num_attention_heads` and `num_channels_per_attention_head`
+            are None or both are set.
+        ValueError: If `model_channels` * `channel_multipliers[0]` is not divisible
+            by 32 (the number of norm groups).
+        RuntimeError: If the length of `context_list` in forward is not the same as the
+            length of `context_dims`.
+        RuntimeError: If the context tensor at index `i` does not have an embed dim equal
+            to `context_dims[i]`.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        model_channels: int,
+        out_channels: int,
+        num_res_blocks_per_level: Union[int, Sequence[int]],
+        attention_resolutions: Sequence[int],
+        channel_multipliers: Sequence[int] = (
+            1,
+            2,
+            4,
+            8,
+        ),
+        context_dims: Optional[Sequence[int]] = None,
+        use_res_block_updown: bool = False,
+        scale_shift_conditional: bool = False,
+        num_attention_heads: Optional[int] = None,
+        num_channels_per_attention_head: Optional[int] = None,
+        num_transformer_layers: Union[int, Sequence[int]] = 1,
+        use_linear_projection_in_transformer: bool = False,
+        dropout: float = 0.0,
+        embed_input_size: bool = False,
+        embed_target_size: bool = False,
+        embed_crop_params: bool = False,
+        coordinate_embedding_dim: int = 0,
+        pooled_text_embedding_dim: int = 0,
+    ):
+        super().__init__()
+        torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
+
+        num_res_blocks_per_level_list: Sequence[int] = []
+        if isinstance(num_res_blocks_per_level, int):
+            num_res_blocks_per_level_list = [num_res_blocks_per_level] * len(
+                channel_multipliers
+            )
+        elif isinstance(num_res_blocks_per_level, Sequence):
+            num_res_blocks_per_level_list = num_res_blocks_per_level
+            if len(num_res_blocks_per_level_list) != len(channel_multipliers):
+                raise ValueError(
+                    "Expected `num_res_blocks_per_level` to have exactly the same length as `channel_multipliers` "
+                    f"({len(channel_multipliers)}), but got {len(num_res_blocks_per_level)}."
+                )
+
+        num_transformer_layers_list: Sequence[int] = (
+            [num_transformer_layers] * len(channel_multipliers)
+            if isinstance(num_transformer_layers, int)
+            else num_transformer_layers
+        )
+
+        if len(num_transformer_layers_list) != len(channel_multipliers):
+            raise ValueError(
+                "Expected `num_transformer_layers` to have exactly the same length as `channel_multipliers` "
+                f"({len(channel_multipliers)}), but got {len(num_transformer_layers_list)}."
+            )
+
+        if num_attention_heads is None and num_channels_per_attention_head is None:
+            raise ValueError(
+                "Exactly one of `num_attention_heads` or `num_channels_per_attention_head` "
+                "must be set, but neither was set."
+ ) + elif ( + num_attention_heads is not None + and num_channels_per_attention_head is not None + ): + raise ValueError( + "Only one of `num_attention_heads` and `num_channels_per_attention_head` can be set," + " but both were set." + ) + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks_per_level = num_res_blocks_per_level_list + self.attention_resolutions = attention_resolutions + self.channel_multipliers = channel_multipliers + self.use_res_block_updown = use_res_block_updown + self.scale_shift_conditional = scale_shift_conditional + self.dropout = dropout + self.context_dims = context_dims + self.num_attention_heads = num_attention_heads + self.num_channels_per_attention_head = num_channels_per_attention_head + self.num_transformer_layers = num_transformer_layers_list + self.use_linear_projection_in_transformer = use_linear_projection_in_transformer + + # time embedding dim is 4 * model_channels to match the original implementation + self.time_embedding_dim = model_channels * 4 + self.time_embedding = self._create_time_embedding() + + # Additional embeddings to add in for SDXL + self.pooled_text_embedding_dim = pooled_text_embedding_dim + assert self.pooled_text_embedding_dim >= 0 + + self.embed_input_size = embed_input_size + self.embed_target_size = embed_target_size + self.embed_crop_params = embed_crop_params + # Multiply by 2 to account for (x, y) coordinates of each param + self.num_coordinates = 2 * ( + self.embed_input_size + self.embed_target_size + self.embed_crop_params + ) + self.coordinate_embedding_dim = coordinate_embedding_dim + assert self.coordinate_embedding_dim >= 0 + assert (self.num_coordinates > 0) == ( + self.coordinate_embedding_dim > 0 + ), "Coordinate embedding can only be used when coordinates are provided" + + if self.num_coordinates and self.coordinate_embedding_dim: + self.pos_embedding = self._create_pos_embedding() + self.pooled_text_embedding_dim = pooled_text_embedding_dim + if ( + self.num_coordinates and self.coordinate_embedding_dim + ) or self.pooled_text_embedding_dim: + self.add_embedding = self._create_add_embedding() + + # TODO: Add support for label embeddings + self.down, down_channels, max_resolution = self._create_downsampling_encoder() + self.bottleneck = self._create_bottleneck( + down_channels[-1], num_layers=self.num_transformer_layers[-1] + ) + self.up = self._create_upsampling_decoder(down_channels, max_resolution) + # input to the output block will have model_channels * channel_multipliers[0] channels + self.out = self._create_out_block(model_channels * channel_multipliers[0]) + + def _create_time_embedding(self) -> nn.Module: + return nn.Sequential( + SinusoidalPositionEmbeddings(embed_dim=self.model_channels), + nn.Linear(self.model_channels, self.time_embedding_dim), + nn.SiLU(), + nn.Linear(self.time_embedding_dim, self.time_embedding_dim), + ) + + def _create_pos_embedding(self) -> nn.Module: + return SinusoidalPositionEmbeddings(embed_dim=self.coordinate_embedding_dim) + + def _create_add_embedding(self) -> nn.Module: + return nn.Sequential( + nn.Linear( + self.coordinate_embedding_dim * self.num_coordinates + + self.pooled_text_embedding_dim, + self.time_embedding_dim, + ), + nn.SiLU(), + nn.Linear(self.time_embedding_dim, self.time_embedding_dim), + ) + + def _add_to_time_embedding( + self, time_embedding: Tensor, additional_embeddings: Dict[str, Tensor] + ) -> Tensor: + add_embeddings = [] + if self.pooled_text_embedding_dim: + assert ( + 
"pooled_text_embed" in additional_embeddings + ), "pooled text embedding not provided" + pooled_text_embed = additional_embeddings["pooled_text_embed"] + + add_embeddings.append(pooled_text_embed) + + if self.num_coordinates and self.coordinate_embedding_dim: + assert ( + "coordinates" in additional_embeddings + ), "coordinates to add to timestep embedding not provided" + assert ( + additional_embeddings["coordinates"].shape[1] == self.num_coordinates + ), f"Unexpected number of coordinate values: {additional_embeddings['coordinates'].shape[1]} != {self.num_coordinates}" + coordinates = additional_embeddings["coordinates"].flatten() + + pos_embed = self.pos_embedding(coordinates) + pos_embed = pos_embed.reshape( + ( + time_embedding.shape[0], + self.coordinate_embedding_dim * self.num_coordinates, + ) + ) + + add_embeddings.append(pos_embed) + + if add_embeddings: + add_embed = torch.concat(add_embeddings, dim=-1) + add_embed = self.add_embedding(add_embed) + + time_embedding = time_embedding + add_embed + + return time_embedding + + def _create_downsampling_encoder(self) -> Tuple[nn.ModuleList, List[int], int]: + """Returns a nn.ModuleList of downsampling residual blocks, channel count + for decoder connections and max downsampling rate. + """ + # Keep track of output channels of every block for thru connections to decoder + down_channels = [] + # Use ADMStack for conv layer so we can pass in conditional inputs and ignore them + init_conv = ADMStack() + init_conv.append_simple_block( + nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=1) + ) + down_channels.append(self.model_channels) + + encoder_blocks = nn.ModuleList([init_conv]) + channels_list = tuple( + [ + self.model_channels * multiplier + for multiplier in [1] + list(self.channel_multipliers) + ] + ) + # keep a track of downsampling rate so we can add attention and + # return the max downsampling rate + downsampling_rate = 1 + num_resolutions = len(self.channel_multipliers) + for level_idx in range(num_resolutions): + block_in = channels_list[level_idx] + block_out = channels_list[level_idx + 1] + res_blocks_list, res_block_channels = res_block_adm_stack( + block_in, + block_out, + self.time_embedding_dim, + self.num_res_blocks_per_level[level_idx], + self.num_transformer_layers[level_idx], + self.scale_shift_conditional, + self.dropout, + attention_fn=self._create_attention + if downsampling_rate in self.attention_resolutions + else None, + ) + # add residual blocks for each level to encoder blocks + encoder_blocks.extend(res_blocks_list) + # add residual block channels for each level to down channels + down_channels.extend(res_block_channels) + + # add downsampling blocks for all levels except the last one + if level_idx != num_resolutions - 1: + downsampling_block = ADMStack() + # use residual block for downsampling + if self.use_res_block_updown: + downsampling_block.append_residual_block( + res_block( + block_out, + block_out, + self.time_embedding_dim, + self.scale_shift_conditional, + self.dropout, + use_downsample=True, + ) + ) + # use conv downsampling + else: + downsampling_block.append_simple_block( + Downsample2D(block_out, asymmetric_padding=False) + ) + encoder_blocks.append(downsampling_block) + down_channels.append(block_out) + # increase downsampling rate for next level + downsampling_rate *= 2 + return encoder_blocks, down_channels, downsampling_rate + + def _create_upsampling_decoder( + self, down_channels: List[int], max_downsampling_rate: int + ) -> nn.ModuleList: + """ + Args: + 
down_channels (List[int]): list of down sample channels for + through connections from encoder. + max_downsampling_rate (int): max downsampling rate from the decoder. + """ + decoder_blocks = nn.ModuleList() + reversed_channels_list = tuple( + reversed( + [ + self.model_channels * multiplier + for multiplier in list(self.channel_multipliers) + + [self.channel_multipliers[-1]] + ] + ) + ) + reversed_res_blocks_per_level = tuple(reversed(self.num_res_blocks_per_level)) + reversed_num_transformer_layers = tuple(reversed(self.num_transformer_layers)) + downsampling_rate = max_downsampling_rate + num_resolutions = len(self.channel_multipliers) + for level_idx in range(num_resolutions): + block_in = reversed_channels_list[level_idx] + block_out = reversed_channels_list[level_idx + 1] + res_blocks_list, _ = res_block_adm_stack( + block_in, + block_out, + self.time_embedding_dim, + # Code ref uses + 1 res blocks in upsampling decoder to add + # extra res block before upsampling + reversed_res_blocks_per_level[level_idx] + 1, + reversed_num_transformer_layers[level_idx], + self.scale_shift_conditional, + self.dropout, + attention_fn=self._create_attention + if downsampling_rate in self.attention_resolutions + else None, + # For through connection from the encoder + additional_input_channels=down_channels, + ) + decoder_blocks.extend(res_blocks_list) + + if level_idx != num_resolutions - 1: + # Key difference between encoder and decoder is that encoder performs + # downsampling as a separate block whose outputs are saved in forward. + # On the other hand in the decoder, upsampling and the last residual + # block are done in a single forward pass operation. + # Hence in encoder, downsampling is wrapped in its own `ADMStack` while in the + # decoder, upsampling and final residual block are wrapped in the same `ADMStack`. + last_res_stack = decoder_blocks[-1] + if self.use_res_block_updown: + last_res_stack.append_residual_block( + res_block( + block_out, + block_out, + self.time_embedding_dim, + self.scale_shift_conditional, + self.dropout, + use_upsample=True, + ) + ) + else: + last_res_stack.append_simple_block(Upsample2D(block_out)) + downsampling_rate = downsampling_rate // 2 + return decoder_blocks + + def _create_bottleneck(self, channels: int, num_layers: int) -> ADMStack: + bottleneck = ADMStack() + bottleneck.append_residual_block( + res_block( + channels, + channels, + self.time_embedding_dim, + self.scale_shift_conditional, + self.dropout, + ) + ) + bottleneck.append_attention_block(self._create_attention(channels, num_layers)) + bottleneck.append_residual_block( + res_block( + channels, + channels, + self.time_embedding_dim, + self.scale_shift_conditional, + self.dropout, + ) + ) + return bottleneck + + def _create_out_block(self, channels: int) -> nn.Module: + conv = nn.Conv2d(channels, self.out_channels, kernel_size=3, padding=1) + # Initialize output projection with zero weight and bias. This helps with + # training stability. Initialization trick from Fixup Initialization. 
+ # https://arxiv.org/abs/1901.09321 + init_module_parameters_to_zero(conv) + return nn.Sequential( + self._create_norm(channels), + nn.SiLU(), + conv, + ) + + def _create_norm(self, channels: int) -> nn.Module: + # Original LDM implementation hardcodes norm groups to 32 + return Fp32GroupNorm(32, channels) + + def _get_num_attention_heads(self, channels: int) -> int: + # if num attention heads is not given, then calculate it by + # dividing channles with num_channels_per_attention_head + if self.num_channels_per_attention_head is not None: + return channels // self.num_channels_per_attention_head + elif self.num_attention_heads is not None: + return self.num_attention_heads + # Should never happen. Adding to make Pyre happy + return 1 + + def _create_attention(self, channels: int, num_layers: int) -> nn.Module: + # original LDM implementation does not pass dropout to SpatialTransformer + # from UNet, instead just hardcodes it. + return SpatialTransformer( + in_channels=channels, + num_heads=self._get_num_attention_heads(channels), + num_layers=num_layers, + context_dims=self.context_dims, + use_linear_projections=self.use_linear_projection_in_transformer, + ) + + def forward( + self, + x: Tensor, + timestep: Tensor, + context_list: Optional[Sequence[Tensor]] = None, + additional_embeddings: Optional[Dict[str, Tensor]] = None, + ) -> Tensor: + # Check if context list is provided, then every context embedding has + # dim as specified in `self.context_dims`. + if isinstance(context_list, Sequence) and isinstance( + self.context_dims, Sequence + ): + # needed to keep Pyre happy + context_dims = self.context_dims or [] + if len(context_list) != len(context_dims): + raise RuntimeError( + f"Expected {len(context_dims)} context tensors. Got {len(context_list)}." + ) + for i in range(len(context_list)): + if context_list[i].size()[-1] != context_dims[i]: + raise RuntimeError( + f"Expect context tensor at index {i} to have {context_dims[i]} dim, " + f"got dim {context_list[i].size()[-1]}" + ) + time_embedding = self.time_embedding(timestep) + + # Add additional conditions to the time embedding if provided + if additional_embeddings: + time_embedding = self._add_to_time_embedding( + time_embedding, additional_embeddings + ) + + h = x + hidden_states = [] + for block in self.down: + h = block(h, time_embedding, context_list) + hidden_states.append(h) + h = self.bottleneck(h, time_embedding, context_list) + for block in self.up: + # through connections from the encoder + h = torch.cat([h, hidden_states.pop()], dim=1) + h = block(h, time_embedding, context_list) + return self.out(h) + + +def res_block_adm_stack( + in_channels: int, + out_channels: int, + time_embedding_dim: int, + num_blocks: int, + num_layers: int, + scale_shift_conditional: bool, + dropout: float = 0.0, + attention_fn: Optional[Callable[[int, int], nn.Module]] = None, + additional_input_channels: Optional[List[int]] = None, +) -> Tuple[nn.ModuleList, List[int]]: + """Create a stack of residual blocks wrapped in ADMStack. + + Args: + in_channels (int): input channels + out_channels (int): output channels + time_embedding_dim (int): dimension of time embeddings + num_blocks (int): number of residual blocks + scale_shift_conditional (bool): If True, scale shift conditionals + dropout (float, optional): dropout rate. Defaults to 0.0. + attention_fn (Callable[[int], nn.Module], optional): function to be called + to build the attention module. If None, no attention module is created. + Defaults to None. 
+ additional_input_channels (List[int], optional): additional input channels for + through connections from UNet encoder to decoder. If None, no additional + input channels are added. Defaults to None. + + Returns: + nn.ModuleList: list of residual blocks. + List[int]: list of output channel sizes from each block. + """ + blocks = nn.ModuleList() + block_channels = [] + block_in, block_out = in_channels, out_channels + for _ in range(num_blocks): + stack = ADMStack() + stack.append_residual_block( + res_block( + block_in + + (additional_input_channels.pop() if additional_input_channels else 0), + block_out, + time_embedding_dim, + scale_shift_conditional, + dropout, + ) + ) + block_in = block_out + if attention_fn is not None: + stack.append_attention_block(attention_fn(block_in, num_layers)) + blocks.append(stack) + block_channels.append(block_out) + return blocks, block_channels + + +def res_block( + in_channels: int, + out_channels: int, + time_embedding_dim: int, + scale_shift_conditional: bool, + dropout: float, + use_upsample: bool = False, + use_downsample: bool = False, +) -> ResBlock: + """Create one residual block based on parameters. + + Args: + in_channels (int): input channels + out_channels (int): output channels + time_embedding_dim (int): dimension of time embeddings + scale_shift_conditional (bool): If True, scale shift conditionals + dropout (float): dropout rate + use_upsample (bool): If True, use upsampling in resdiual block + use_downsample (bool): If True, use downsampling in resdiual block + + Returns: + ResBlock: Residual block module + + Raises: + ValueError: If both `use_upsample` and `use_downsample` are True. + + """ + if use_downsample and use_upsample: + raise ValueError("Cannot use both upsample and downsample in res block") + res_block_partial = partial( + ResBlock, + in_channels=in_channels, + out_channels=out_channels, + pre_outconv_dropout=dropout, + scale_shift_conditional=scale_shift_conditional, + use_upsample=use_upsample, + use_downsample=use_downsample, + cond_proj=adm_cond_proj( + dim_cond=time_embedding_dim, + cond_channels=out_channels, + scale_shift_conditional=scale_shift_conditional, + ), + ) + if in_channels != out_channels: + residual_block = res_block_partial( + skip_conv=nn.Conv2d(in_channels, out_channels, kernel_size=1) + ) + else: + residual_block = res_block_partial() + # Initialize residual block's out projection with zero weight and bias. + # This helps with training stability. Initialization trick from Fixup + # Initialization : https://arxiv.org/abs/1901.09321 + init_module_parameters_to_zero(residual_block.out_block[-1]) + return residual_block + + +class LDMModel(nn.Module): + """Implements the LDM model used by Latent Diffusion Models (LDMs). This is a + lightweight class that is responsible for composing the LDMUNet and handles + building the input conditioning tensors from the passed context dictionary. This + allows us to conveniently use the DDPM, DDIM, CFGuidance, etc modules across + different models. + + Attributes: + unet (LDMUNet): Initialized UNet used by the model. + cond_keys (Sequence[str]): Ordered sequence of conditioning keys to build + conditional input for cross-attention in unet model. Defaults to tuple(). + additional_cond_keys (Optional[Sequence[str]]): List of conditioning keys to + be passed as additional conditioning in unet model. These are usually + projected onto and pooled with the timestep embeddings. Defaults to None. 
+ + Args: + x (Tensor): input Tensor of shape [b, in_channels, h, w] + timestep (Tensor): diffusion timesteps of shape [b, ] + conditional_inputs (Dict[str, Tensor], optional): Optional dictionary of + context tensors. Key is the conditioning key, value is a tensor of shape + [b, seq_len, context_dim]. Defaults to None. + + Raises: + KeyError: If any of the keys in `cond_keys` or `additional_embedding_keys` are + not present in `conditional_inputs`. + RuntimeError: If conditional input does not have 3 dims. + """ + + def __init__( + self, + unet: LDMUNet, + cond_keys: Sequence[str] = tuple(), + additional_cond_keys: Optional[Sequence[str]] = None, + ): + super().__init__() + self.model = unet + self.cond_keys = cond_keys + self.additional_cond_keys = additional_cond_keys + + def forward( + self, + x: Tensor, + timesteps: Tensor, + conditional_inputs: Optional[Dict[str, Tensor]] = None, + ): + context_list = None + additional_embeddings = None + if conditional_inputs is not None: + context_list = [conditional_inputs[k] for k in self.cond_keys] + for c in context_list: + if len(c.size()) != 3: + raise RuntimeError( + f"Expected context tensor to have 3 dims, got {len(c.size())}." + ) + + if self.additional_cond_keys is not None: + additional_embeddings = { + k: conditional_inputs[k] for k in self.additional_cond_keys + } + + h = self.model(x, timesteps, context_list, additional_embeddings) + return DiffusionOutput(prediction=h) diff --git a/torchmultimodal/diffusion_labs/models/ldm/spatial_transformer.py b/torchmultimodal/diffusion_labs/models/ldm/spatial_transformer.py new file mode 100644 index 00000000..a36fe322 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/ldm/spatial_transformer.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence + +import torch +from torch import nn, Tensor +from torchmultimodal.modules.layers.activation import GEGLU +from torchmultimodal.modules.layers.multi_head_attention import ( + MultiHeadAttentionWithCache, +) +from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm, Fp32LayerNorm +from torchmultimodal.utils.common import init_module_parameters_to_zero + + +class SpatialTransformerCrossAttentionLayer(nn.Module): + """Transformer encoder layer with cross-attention mechanism. This layer contains + 2 attention blocks that use PyTorch's scaled dot product attention. The first attention + block performs self attention block, while the second performs cross-attention. If + `context_dim` is not set or `context` is not passed, the second block defaults to + self attention. + + Code ref: + https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/attention.py#L196 + + Attributes: + d_model (int): size of hidden dimension of input + num_heads (int): number of attention heads + context_dim (int, optional): size of context embedding. If None, + use self attention. Defaults to None. + dropout (float): Dropout to apply post attention layers. + Defaults to 0. + attention_dropout (float): Dropout to apply to scaled dot product + attention. Defaults to 0. + + Args: + x (Tensor): input Tensor of shape [b, seq_len, d_model] + context (Tensor, optional): Context tensor of shape + [b, seq_len, context_dim]. Defaults to None. 
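# --- Illustrative sketch (not part of the patch): how LDMModel routes a conditioning
# dict into the UNet call. The stub module and the key names ("clip_text",
# "pooled_text_embed") are assumptions for the example only; the import path is assumed.
import torch
from torch import nn
from torchmultimodal.diffusion_labs.models.ldm.ldm import LDMModel

class StubUNet(nn.Module):
    def forward(self, x, timestep, context_list=None, additional_embeddings=None):
        # cond_keys entries arrive here, in order, as the cross-attention contexts;
        # additional_cond_keys entries arrive as a separate dict.
        assert len(context_list) == 1 and context_list[0].shape == (2, 7, 16)
        assert set(additional_embeddings) == {"pooled_text_embed"}
        return x

model = LDMModel(
    StubUNet(),
    cond_keys=["clip_text"],
    additional_cond_keys=["pooled_text_embed"],
)
out = model(
    torch.randn(2, 4, 8, 8),
    torch.randint(1, 10, (2,)),
    conditional_inputs={
        "clip_text": torch.randn(2, 7, 16),
        "pooled_text_embed": torch.randn(2, 32),
    },
)
assert out.prediction.shape == (2, 4, 8, 8)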
+ """ + + def __init__( + self, + d_model: int, + num_heads: int, + context_dim: Optional[int] = None, + dropout: float = 0.0, + attention_dropout: float = 0.0, + ): + super().__init__() + torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") + + # Optional context is added only for parity. + # TODO: Remove optional context if this code path is + # not used by genai use cases + self.use_context = context_dim is not None + self.self_attn_layernorm = Fp32LayerNorm(d_model) + self.self_attn = MultiHeadAttentionWithCache( + dim_q=d_model, + dim_kv=d_model, + num_heads=num_heads, + dropout=attention_dropout, + add_bias=False, + ) + self.self_attn_dropout = nn.Dropout(dropout) + + # If no context is passed, then cross-attention blocks end up performing self-attention + self.cross_attn_layernorm = Fp32LayerNorm(d_model) + self.cross_attn = MultiHeadAttentionWithCache( + dim_q=d_model, + # defaults to self attention if context dim not provided + dim_kv=context_dim if context_dim else d_model, + num_heads=num_heads, + dropout=attention_dropout, + add_bias=False, + ) + self.cross_attn_dropout = nn.Dropout(dropout) + + # scaling the projection dimension by 4 to match the logic in the original implementation + projection_dim = d_model * 4 + self.feed_forward_block = nn.Sequential( + Fp32LayerNorm(d_model), + # Projection dim is scaled by 2 to perform the GELU operation which chunks the + # input projection into 2 parts and combines them to obtain the activation. + nn.Linear(d_model, projection_dim * 2), + GEGLU(), + nn.Dropout(dropout), + nn.Linear(projection_dim, d_model), + ) + + def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor: + if not self.use_context: + context = None + + h = self.self_attn_layernorm(x) + h = self.self_attn_dropout(self.self_attn(query=h, key=h, value=h)) + x + h_res = h + h = self.cross_attn_layernorm(h) + context = context if context is not None else h + h = ( + self.cross_attn_dropout( + self.cross_attn(query=h, key=context, value=context) + ) + + h_res + ) + h = self.feed_forward_block(h) + h + return h + + +class SpatialTransformer(nn.Module): + """Transformer block with cross-attention mechanism that operates on + image-like data. First, it flattens the spatial dimensions of the image + to shape (batch_size x (height * width) x num_channels) and applies an + input projection. Next, the projected input is passed ta block of stacked + transformer cross attention layers. Finally, the output goes through another + projection and is reshaped back to an image. + + Code ref: + https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/attention.py#L218 + + Attributes: + in_channls(int): number of channels in input image. + num_heads (int): number of attention heads. + num_layers (int): number of transformer encoder layers. + context_dims (Sequence[int], optional): Size of context embedding for + every transformer layer. If len(context_dim) < num_layers, expand context_dims + to have length num_layers. If None, use self attention for every layer. + Defaults to None. + use_linear_projections (bool, optional): If True, use linear input and output + projections instead of 1x1 conv projections. Defaults to False. + dropout (float): Dropout to apply post attention layers. + Defaults to 0. + attention_dropout (float): Dropout to apply to scaled dot product + attention. Defaults to 0. + norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32. 
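# --- Illustrative sketch (not part of the patch): exercising the layer documented above,
# with arbitrary sizes.
import torch
from torchmultimodal.diffusion_labs.models.ldm.spatial_transformer import (
    SpatialTransformerCrossAttentionLayer,
)

layer = SpatialTransformerCrossAttentionLayer(d_model=64, num_heads=4, context_dim=32)
x = torch.randn(2, 16, 64)       # flattened spatial tokens
ctx = torch.randn(2, 7, 32)      # e.g. text-encoder states
assert layer(x, ctx).shape == x.shape

# Without a context_dim, the second attention block falls back to self-attention.
self_only = SpatialTransformerCrossAttentionLayer(d_model=64, num_heads=4)
assert self_only(x).shape == x.shape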
+ norm_eps (float): epsilon used in the GroupNorm layer. Defaults to 1e-6. + + Args: + x (Tensor): input Tensor of shape [b, seq_len, d_model] + context (Sequence[Tensor], optional): List of context tensors of shape + [b, seq_len, context_dim] each. Must be equal to the number of + transformer layers. Defaults to None. + + Raises: + ValueError: If `num_layers` is not a multiple of length of `context_dims`. + RuntimeError: If `len(self.transformer_layers)` is not a multiple of length of `context`. + """ + + def __init__( + self, + in_channels: int, + num_heads: int, + num_layers: int, + context_dims: Optional[Sequence[int]] = None, + use_linear_projections: bool = False, + dropout: float = 0.0, + attention_dropout: float = 0.0, + norm_groups: int = 32, + norm_eps: float = 1e-6, + ): + super().__init__() + torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") + + # Optional context is added only for parity with oss implementation. + # TODO: Remove optional context if this code path is + # not used by genai use cases + self.use_context = context_dims is not None + if context_dims is None: + context_dims = [context_dims] * num_layers + elif num_layers != len(context_dims): + assert isinstance(context_dims, list) + if num_layers % len(context_dims) != 0: + raise ValueError( + "`num_layers` must be a multiple of the length of `context_dims`." + ) + + repeating_factor = int(num_layers / len(context_dims)) + print( + f"WARNING: context dims {context_dims} of length {len(context_dims)} does not match " + f"'num_layers'={num_layers}. Expanding context_dims to {context_dims * repeating_factor}." + ) + context_dims = context_dims * repeating_factor + + self.use_linear_projections = use_linear_projections + + self.norm = Fp32GroupNorm( + num_groups=norm_groups, num_channels=in_channels, eps=norm_eps + ) + + # Initialize input and output projections. If using linear projections, both + # projections are initialized with nn.Linear, otherwise use 1x1 convolutions. + if self.use_linear_projections: + self.in_projection = nn.Linear(in_channels, in_channels) + self.out_projection = nn.Linear(in_channels, in_channels) + else: + self.in_projection = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.out_projection = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + # Initialize out projection with zero weight and bias. This helps with + # training stability. Initialization trick from Fixup Initialization. + # https://arxiv.org/abs/1901.09321 + init_module_parameters_to_zero(self.out_projection) + + self.transformer_layers = nn.ModuleList( + [ + SpatialTransformerCrossAttentionLayer( + d_model=in_channels, + num_heads=num_heads, + context_dim=context_dims[i], + dropout=dropout, + attention_dropout=attention_dropout, + ) + for i in range(num_layers) + ] + ) + + def forward( + self, + x: Tensor, + context: Optional[Sequence[Tensor]] = None, + ) -> Tensor: + # If context_dims were not provided in init, default to self attention + # by setting passed context to None. + # TODO: Remove this logic if there is no case where context or context dims are None. + if not self.use_context: + context = None + if isinstance(context, Sequence) and len(context) != len( + self.transformer_layers + ): + if len(self.transformer_layers) % len(context) != 0: + raise RuntimeError( + "`len(self.transformer_layers)` must be a multiple of the length of `context`." 
+ ) + + print( + f"WARNING: context of length {len(context)} does not match 'num_layers'={len(self.transformer_layers)}." + f" Contexts will be re-used {int(len(self.transformer_layers)/len(context))} times." + ) + + _, _, H, W = x.shape + h = self.norm(x) + # For linear projection, first reshape and then apply projection + if self.use_linear_projections: + # b * c * h * w -> b * (h * w) * c + h = torch.transpose(torch.flatten(h, start_dim=2), 1, 2) + h = self.in_projection(h) + else: # For conv projection, first apply projection and then reshape + h = self.in_projection(h) + # b * c * h * w -> b * (h * w) * c + h = torch.transpose(torch.flatten(h, start_dim=2), 1, 2) + + for i in range(len(self.transformer_layers)): + if isinstance(context, Sequence): + _context = context[i % len(context)] + else: + _context = None + + h = self.transformer_layers[i](h, context=_context) + + # For linear projection, first apply projection and then unflatten + if self.use_linear_projections: + h = self.out_projection(h) + h = torch.unflatten(torch.transpose(h, 1, 2), dim=2, sizes=(H, W)) + else: # For conv projection, first unflatten and then apply projection + h = torch.unflatten(torch.transpose(h, 1, 2), dim=2, sizes=(H, W)) + h = self.out_projection(h) + + return h + x diff --git a/torchmultimodal/diffusion_labs/models/vae/__init__.py b/torchmultimodal/diffusion_labs/models/vae/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/vae/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/diffusion_labs/models/vae/attention.py b/torchmultimodal/diffusion_labs/models/vae/attention.py new file mode 100644 index 00000000..3772a8b8 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/vae/attention.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +from torch import nn, Tensor +from torch.nn import functional as F +from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm + + +class AttentionResBlock(nn.Module): + """Attention block in the LDM Autoencoder that consists of group norm, attention, + conv projection and a residual connection. + + Follows the architecture described in "High-Resolution Image Synthesis with Latent + Diffusion Models" (https://arxiv.org/abs/2112.10752) + + Code ref: + https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/diffusionmodules/model.py#LL150C1-L150C6 + + Attributes: + num_channels (int): channel dim expected in input, determines embedding dim of + q, k, v in attention module. Needs to be divisible by norm_groups. + attn_module (nn.Module): Module of attention mechanism to use. + norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32. + norm_eps (float): epsilon used in the GroupNorm layer. Defaults to 1e-6. + + Args: + x (Tensor): input Tensor of shape [b, c, h, w] + + Raises: + ValueError: If `num_channels` is not divisible by `norm_groups`. 
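# --- Illustrative sketch (not part of the patch): SpatialTransformer on image-like
# input. in_channels must be divisible by norm_groups (32 by default); a context list
# shorter than num_layers is reused cyclically, as in forward() above.
import torch
from torchmultimodal.diffusion_labs.models.ldm.spatial_transformer import SpatialTransformer

transformer = SpatialTransformer(
    in_channels=64,
    num_heads=4,
    num_layers=2,
    context_dims=[32],  # expanded to [32, 32] with a warning
)
x = torch.randn(2, 64, 8, 8)
context = [torch.randn(2, 7, 32)]  # reused for both layers, with a warning
assert transformer(x, context).shape == x.shape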
+ """ + + def __init__( + self, + num_channels: int, + attn_module: nn.Module, + norm_groups: int = 32, + norm_eps: float = 1e-6, + ): + super().__init__() + torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") + + if num_channels % norm_groups != 0: + raise ValueError("Channel dims need to be divisible by norm_groups") + + self.net = nn.Sequential( + OrderedDict( + [ + ("norm", Fp32GroupNorm(norm_groups, num_channels, norm_eps)), + ("attn", attn_module), + ("out", nn.Conv2d(num_channels, num_channels, kernel_size=1)), + ] + ) + ) + + def forward(self, x: Tensor) -> Tensor: + return self.net(x) + x + + +class VanillaAttention(nn.Module): + """Attention module used in the LDM Autoencoder. Similar to standard Q, k V attention, + but using 2d convolutions instead of linear projections for obtaining q, k, v tensors. + + Follows the architecture described in "High-Resolution Image Synthesis with Latent + Diffusion Models" (https://arxiv.org/abs/2112.10752) + + Code ref: + https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/diffusionmodules/model.py#LL150C1-L150C6 + + Attributes: + num_channels (int): channel dim expected in input, determines embedding dim of q, k, v. + + Args: + x (Tensor): input Tensor of shape [b, c, h, w] + """ + + def __init__(self, num_channels: int): + super().__init__() + torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") + + self.query = nn.Conv2d(num_channels, num_channels, kernel_size=1) + self.key = nn.Conv2d(num_channels, num_channels, kernel_size=1) + self.value = nn.Conv2d(num_channels, num_channels, kernel_size=1) + + def forward(self, x: Tensor) -> Tensor: + q, k, v = self.query(x), self.key(x), self.value(x) + B, C, H, W = q.shape + # [B, C, H, W] -> [B, H*W, C] + q, k, v = (t.reshape(B, C, H * W).permute(0, 2, 1) for t in (q, k, v)) + # [B, H*W, C] + out = F.scaled_dot_product_attention(q, k, v) + # [B, H*W, C] -> [B, C, H, W] + return out.permute(0, 2, 1).reshape(B, C, H, W) + + +def attention_res_block( + channels: int, + norm_groups: int = 32, + norm_eps: float = 1e-6, +) -> AttentionResBlock: + return AttentionResBlock( + channels, VanillaAttention(channels), norm_groups, norm_eps + ) diff --git a/torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py b/torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py new file mode 100644 index 00000000..3dad909c --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import Sequence + +import torch +from torch import nn, Tensor +from torchmultimodal.diffusion_labs.models.vae.attention import attention_res_block +from torchmultimodal.diffusion_labs.models.vae.res_block import ResBlock +from torchmultimodal.diffusion_labs.models.vae.residual_sampling import ( + Downsample2D, + Upsample2D, +) +from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm + + +class ResNetEncoder(nn.Module): + """Resnet encoder used in the LDM Autoencoder that consists of a init convolution, + downsampling resnet blocks, middle resnet blocks with attention and output convolution block + with group normalization and nonlinearity. 
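# --- Illustrative sketch (not part of the patch): the VanillaAttention-based residual
# block operates on [b, c, h, w] and preserves the shape; channels must be divisible
# by norm_groups (32 by default).
import torch
from torchmultimodal.diffusion_labs.models.vae.attention import attention_res_block

block = attention_res_block(channels=64)
x = torch.randn(2, 64, 8, 8)
assert block(x).shape == x.shape  # attention runs over the 8 * 8 = 64 spatial positions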
+ + Follows the architecture described in "High-Resolution Image Synthesis with Latent + Diffusion Models" (https://arxiv.org/abs/2112.10752) + + Code ref: + https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/diffusionmodules/model.py#L368 + + Attributes: + in_channels (int): number of input channels. + z_channels (int): number of latent channels. + channels (int): number of channels in the initial convolution layer. + num_res_block (int): number of residual blocks at each resolution. + channel_multipliers (Sequence[int]): list of channel multipliers. Defaults to [1, 2, 4, 8]. + dropout (float): dropout probability. Defaults to 0.0. + double_z (bool): whether to use double z_channels for images or not. Defaults to True. + norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32. + norm_eps (float): epsilon used in the GroupNorm layer. Defaults to 1e-6. + + Args: + x (Tensor): input Tensor of shape [b, c, h, w] + + Raises: + ValueError: If `channels` * `channel_multipliers[-1]` is not divisible by `norm_groups`. + """ + + def __init__( + self, + in_channels: int, + z_channels: int, + channels: int, + num_res_blocks: int, + channel_multipliers: Sequence[int] = ( + 1, + 2, + 4, + 8, + ), + dropout: float = 0.0, + double_z: bool = True, + norm_groups: int = 32, + norm_eps: float = 1e-6, + ): + super().__init__() + torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") + + # initial convolution + self.init_conv = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1) + + # downsampling block + self.down_block = nn.Sequential() + channels_list = tuple( + [channels * multiplier for multiplier in [1] + list(channel_multipliers)] + ) + num_resolutions = len(channel_multipliers) + for level_idx in range(num_resolutions): + block_in = channels_list[level_idx] + block_out = channels_list[level_idx + 1] + self.down_block.append( + res_block_stack( + block_in, + block_out, + num_res_blocks, + dropout, + needs_downsample=True + if level_idx != num_resolutions - 1 + else False, + norm_groups=norm_groups, + norm_eps=norm_eps, + ) + ) + + mid_channels = channels_list[-1] + self.mid_block = nn.Sequential( + res_block(mid_channels, mid_channels, dropout, norm_groups, norm_eps), + attention_res_block(mid_channels, norm_groups, norm_eps), + res_block(mid_channels, mid_channels, dropout, norm_groups, norm_eps), + ) + + if mid_channels % norm_groups != 0: + raise ValueError( + "Channel dims obtained by multiplying channels with last" + " item in channel_multipliers needs to be divisible by norm_groups" + ) + + self.out_block = nn.Sequential( + Fp32GroupNorm( + num_groups=norm_groups, num_channels=mid_channels, eps=norm_eps + ), + nn.SiLU(), + nn.Conv2d( + mid_channels, + out_channels=2 * z_channels if double_z else z_channels, + kernel_size=3, + padding=1, + ), + ) + + def forward(self, x: Tensor) -> Tensor: + h = self.init_conv(x) + h = self.down_block(h) + h = self.mid_block(h) + h = self.out_block(h) + return h + + +class ResNetDecoder(nn.Module): + """Resnet decoder used in the LDM Autoencoder that consists of a init convolution, + middle resnet blocks with attention, upsamling resnet blocks and output convolution + block with group normalization and nonlinearity. Optionally, also supports alpha + channel in output. 
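# --- Illustrative sketch (not part of the patch): output shape of the encoder above.
# The values are arbitrary; channels times each multiplier must stay divisible by
# norm_groups (32 by default).
import torch
from torchmultimodal.diffusion_labs.models.vae.encoder_decoder import ResNetEncoder

encoder = ResNetEncoder(
    in_channels=3, z_channels=4, channels=128, num_res_blocks=2,
    channel_multipliers=(1, 2, 4),
)
x = torch.randn(1, 3, 32, 32)
# Every level except the last downsamples by 2 (two halvings here), and double_z=True
# doubles the latent channels so they can hold the mean and log-variance.
assert encoder(x).shape == (1, 2 * 4, 8, 8)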
+ + Follows the architecture described in "High-Resolution Image Synthesis with Latent + Diffusion Models" (https://arxiv.org/abs/2112.10752) + + Code ref: + https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/diffusionmodules/model.py#L462 + + Attributes: + out_channels (int): number of channels in output image. + z_channels (int): number of latent channels. + channels (int): number of channels to be used with channel multipliers. + num_res_block (int): number of residual blocks at each resolution. + channel_multipliers (Sequence[int]): list of channel multipliers used by the encoder. + Decoder uses them in reverse order. Defaults to [1, 2, 4, 8]. + dropout (float): dropout probability. Defaults to 0.0. + norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32. + norm_eps (float): epsilon used in the GroupNorm layer. Defaults to 1e-6. + output_alpha_channel (bool): whether to include an alpha channel in the output. + Defaults to False. + + Args: + z (Tensor): input Tensor of shape [b, c, h, w] + + Raises: + ValueError: If `channels` * `channel_multipliers[-1]` is not divisible by `norm_groups`. + """ + + def __init__( + self, + out_channels: int, + z_channels: int, + channels: int, + num_res_blocks: int, + channel_multipliers: Sequence[int] = ( + 1, + 2, + 4, + 8, + ), + dropout: float = 0.0, + norm_groups: int = 32, + norm_eps: float = 1e-6, + output_alpha_channel: bool = False, + ): + super().__init__() + torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") + self.output_alpha_channel = output_alpha_channel + + channels_list = tuple( + reversed( + [ + channels * multiplier + for multiplier in list(channel_multipliers) + + [channel_multipliers[-1]] + ] + ) + ) + mid_channels = channels_list[0] + + # initial convolution + self.init_conv = nn.Conv2d(z_channels, mid_channels, kernel_size=3, padding=1) + + # middle block + self.mid_block = nn.Sequential( + res_block(mid_channels, mid_channels, dropout, norm_groups, norm_eps), + attention_res_block(mid_channels, norm_groups, norm_eps), + res_block(mid_channels, mid_channels, dropout, norm_groups, norm_eps), + ) + + # upsample block + self.up_block = nn.Sequential() + num_resolutions = len(channel_multipliers) + for level_idx in range(num_resolutions): + block_in = channels_list[level_idx] + block_out = channels_list[level_idx + 1] + self.up_block.append( + res_block_stack( + block_in, + block_out, + # decoder creates 1 additional res block compared to encoder. + # not sure about intuition, but seems to be used everywhere in OSS. 
+ num_res_blocks + 1, + dropout, + needs_upsample=True if level_idx != num_resolutions - 1 else False, + norm_groups=norm_groups, + norm_eps=norm_eps, + ) + ) + + # output nonlinearity block + post_upsample_channels = channels_list[-1] + if post_upsample_channels % norm_groups != 0: + raise ValueError( + "Channel dims obtained by multiplying channels with first" + " item in channel_multipliers needs to be divisible by norm_groups" + ) + self.out_nonlinearity_block = nn.Sequential( + Fp32GroupNorm( + num_groups=norm_groups, + num_channels=post_upsample_channels, + eps=norm_eps, + ), + nn.SiLU(), + ) + + # output projections + self.conv_out = nn.Conv2d( + post_upsample_channels, out_channels, kernel_size=3, padding=1 + ) + if self.output_alpha_channel: + self.alpha_conv_out = nn.Conv2d( + post_upsample_channels, 1, kernel_size=3, padding=1 + ) + + def forward(self, z: Tensor) -> Tensor: + h = self.init_conv(z) + h = self.mid_block(h) + h = self.up_block(h) + h = self.out_nonlinearity_block(h) + + # If alpha channel is required as output, compute it separately with its + # own conv layer and concatenate with the output from the out convolulution + if self.output_alpha_channel: + h = torch.cat((self.conv_out(h), self.alpha_conv_out(h)), dim=1) + else: + h = self.conv_out(h) + + return h + + +def res_block_stack( + in_channels: int, + out_channels: int, + num_blocks: int, + dropout: float = 0.0, + needs_upsample: bool = False, + needs_downsample: bool = False, + norm_groups: int = 32, + norm_eps: float = 1e-6, +) -> nn.Module: + if needs_upsample and needs_downsample: + raise ValueError("Cannot use both upsample and downsample in res block") + block_in, block_out = in_channels, out_channels + block_stack = nn.Sequential() + for _ in range(num_blocks): + block_stack.append( + res_block(block_in, block_out, dropout, norm_groups, norm_eps) + ) + block_in = block_out + if needs_downsample: + block_stack.append(Downsample2D(out_channels)) + if needs_upsample: + block_stack.append(Upsample2D(out_channels)) + return block_stack + + +def res_block( + in_channels: int, + out_channels: int, + dropout: float = 0.0, + norm_groups: int = 32, + norm_eps: float = 1e-6, +) -> ResBlock: + res_block_partial = partial( + ResBlock, + in_channels=in_channels, + out_channels=out_channels, + pre_outconv_dropout=dropout, + scale_shift_conditional=False, + norm_groups=norm_groups, + norm_eps=norm_eps, + ) + if in_channels != out_channels: + return res_block_partial( + skip_conv=nn.Conv2d(in_channels, out_channels, kernel_size=1) + ) + else: + return res_block_partial() diff --git a/torchmultimodal/diffusion_labs/models/vae/res_block.py b/torchmultimodal/diffusion_labs/models/vae/res_block.py new file mode 100644 index 00000000..02764d20 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/vae/res_block.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor +from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm + + +class ResBlock(nn.Module): + """Residual block in the ADM net. Supports projecting a conditional embedding to add to the hidden state. 
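# --- Illustrative sketch (not part of the patch): a decoder mirroring the encoder
# configuration above, mapping an 8x8 latent back to a 32x32 image.
import torch
from torchmultimodal.diffusion_labs.models.vae.encoder_decoder import ResNetDecoder

decoder = ResNetDecoder(
    out_channels=3, z_channels=4, channels=128, num_res_blocks=2,
    channel_multipliers=(1, 2, 4),
)
z = torch.randn(1, 4, 8, 8)
assert decoder(z).shape == (1, 3, 32, 32)  # two upsampling levels: 8 -> 16 -> 32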
+ This typically contains the timestep embedding, but can also contain class embedding for classifier free guidance, + CLIP image embedding and text encoder output for text-to-image generation as in DALL-E 2, or anything you want to + condition the diffusion model on. If conditional embedding is not passed, the hidden state is simply passed through. + + Follows the architecture described in "Diffusion Models Beat GANs on Image Synthesis" + (https://arxiv.org/abs/2105.05233) and BigGAN residual blocks (https://arxiv.org/abs/1809.11096). + + Code ref: + https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/unet.py#L143 + + + Attributes: + in_channels (int): num channels expected in input. Needs to be divisible by norm_groups. + out_channels (int): num channels desired in output. Needs to be divisible by norm_groups. + use_upsample (bool): include nn.Upsample layer before first conv on hidden state and on skip connection. + Defaults to False. Cannot be True if use_downsample is True. + use_downsample (bool): include nn.AvgPool2d layer before first conv on hidden state and on skip connection. + Defaults to False. Cannot be True if use_upsample is True. + activation (nn.Module): activation used before convs. Defaults to nn.SiLU(). + skip_conv (nn.Module): module used for additional convolution on skip connection. Defaults to nn.Identity(). + cond_proj (Optional[nn.Module]): module used for conditional embedding projection. Defaults to None. + rescale_skip_connection (bool): whether to rescale skip connection by 1/sqrt(2), as described in "Diffusion + Models Beat GANs on Image Synthesis" (https://arxiv.org/abs/2105.05233). Defaults to False. + scale_shift_conditional (bool): if True, splits conditional embedding into two separate projections, + and adds to hidden state as Norm(h)(w + 1) + b, as described in Appendix A in + "Improved Denoising Diffusion Probabilistic Models" (https://arxiv.org/abs/2102.09672). + Defaults to True. + pre_outconv_dropout (float): dropout probability before the second conv. Defaults to 0.1. + norm_groups (int): number of groups used in GroupNorm layer. Defaults to 32. + norm_eps (float): Epsilon used in the GroupNorm layer. Defaults to 1e-5. + + + Args: + x (Tensor): input Tensor of shape [B x C x H x W] + conditional_embedding (Tensor, optional): conditioning embedding vector of shape [B x C]. + If None, hidden state is passed through. + + Raises: + TypeError: When skip_conv is not defined and in_channels != out_channels. + TypeError: When use_upsample and use_downsample are both True. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + use_upsample: bool = False, + use_downsample: bool = False, + activation: nn.Module = nn.SiLU(), + skip_conv: nn.Module = nn.Identity(), + cond_proj: Optional[nn.Module] = None, + rescale_skip_connection: bool = False, + scale_shift_conditional: bool = True, + pre_outconv_dropout: float = 0.1, + norm_groups: int = 32, + norm_eps: float = 1e-05, + ): + super().__init__() + torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}") + + if isinstance(skip_conv, nn.Identity) and in_channels != out_channels: + raise ValueError( + "You must specify a skip connection conv if out_channels != in_channels" + ) + + if in_channels % norm_groups != 0 or out_channels % norm_groups != 0: + raise ValueError("Channel dims need to be divisible by norm_groups") + + if use_downsample and use_upsample: + raise ValueError("Cannot use both upsample and downsample in res block") + elif use_downsample: + hidden_updownsample_layer = nn.AvgPool2d(kernel_size=2, stride=2) + skip_updownsample_layer = nn.AvgPool2d(kernel_size=2, stride=2) + elif use_upsample: + hidden_updownsample_layer = nn.Upsample(scale_factor=2, mode="nearest") + skip_updownsample_layer = nn.Upsample(scale_factor=2, mode="nearest") + else: + hidden_updownsample_layer = nn.Identity() + skip_updownsample_layer = nn.Identity() + + self.cond_proj = cond_proj + self.in_block = nn.Sequential( + Fp32GroupNorm( + norm_groups, in_channels, eps=norm_eps + ), # groups = 32 from code ref + activation, + hidden_updownsample_layer, + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + ) + + self.out_group_norm = Fp32GroupNorm(norm_groups, out_channels, eps=norm_eps) + self.out_block = nn.Sequential( + activation, + nn.Dropout(pre_outconv_dropout), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip_block = nn.Sequential( + skip_updownsample_layer, + skip_conv, + ) + + self.scale_shift_conditional = scale_shift_conditional + self.rescale_skip_connection = rescale_skip_connection + + def forward( + self, + x: Tensor, + conditional_embedding: Optional[Tensor] = None, + ) -> Tensor: + skip = self.skip_block(x) + h = self.in_block(x) + + # Add conditional embedding to h, if they are passed and cond_proj is defined + if conditional_embedding is not None and self.cond_proj is not None: + t = self.cond_proj(conditional_embedding) + # [b, c] -> [b, c, 1, 1] + t = t.unsqueeze(-1).unsqueeze(-1) + + # If specified, split conditional embedding into two separate projections. + # Use half to multiply with hidden state and half to add. + # This is typically done after normalization. 
+ if self.scale_shift_conditional: + h = self.out_group_norm(h) + scale, shift = torch.chunk(t, 2, dim=1) + h = h * (1 + scale) + shift + h = self.out_block(h) + else: + h = self.out_block(self.out_group_norm(h + t)) + else: + h = self.out_block(self.out_group_norm(h)) + + if self.rescale_skip_connection: + h = (skip + h) / 1.414 + else: + h = skip + h + return h + + +def adm_res_block( + in_channels: int, + out_channels: int, + dim_cond: int, + rescale_skip_connection: bool = False, +) -> ResBlock: + if in_channels != out_channels: + return adm_res_skipconv_block(in_channels, out_channels, dim_cond) + return ResBlock( + in_channels=in_channels, + out_channels=out_channels, + rescale_skip_connection=rescale_skip_connection, + cond_proj=adm_cond_proj(dim_cond, out_channels), + ) + + +def adm_res_downsample_block( + num_channels: int, + dim_cond: int, + rescale_skip_connection: bool = False, +) -> ResBlock: + return ResBlock( + in_channels=num_channels, + out_channels=num_channels, + use_downsample=True, + rescale_skip_connection=rescale_skip_connection, + cond_proj=adm_cond_proj(dim_cond, num_channels), + ) + + +def adm_res_skipconv_block( + in_channels: int, + out_channels: int, + dim_cond: int, + rescale_skip_connection: bool = False, +) -> ResBlock: + skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) + return ResBlock( + in_channels=in_channels, + out_channels=out_channels, + skip_conv=skip_conv, + rescale_skip_connection=rescale_skip_connection, + cond_proj=adm_cond_proj(dim_cond, out_channels), + ) + + +def adm_res_upsample_block( + num_channels: int, + dim_cond: int, + rescale_skip_connection: bool = False, +) -> ResBlock: + return ResBlock( + in_channels=num_channels, + out_channels=num_channels, + use_upsample=True, + rescale_skip_connection=rescale_skip_connection, + cond_proj=adm_cond_proj(dim_cond, num_channels), + ) + + +def adm_cond_proj( + dim_cond: int, + cond_channels: int, + scale_shift_conditional: bool = True, +) -> nn.Module: + if scale_shift_conditional: + cond_channels *= 2 + return nn.Sequential(nn.SiLU(), nn.Linear(dim_cond, cond_channels)) diff --git a/torchmultimodal/diffusion_labs/models/vae/residual_sampling.py b/torchmultimodal/diffusion_labs/models/vae/residual_sampling.py new file mode 100644 index 00000000..f6e310f0 --- /dev/null +++ b/torchmultimodal/diffusion_labs/models/vae/residual_sampling.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn as nn, Tensor +from torch.nn import functional as F + + +class Upsample2D(nn.Module): + """2-Dimensional upsampling layer with nearest neighbor interpolation and + 2D convolution, used for image decoders. + + Code ref: + https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L91 + + Attributes: + channels (int): Number of channels in the input. + + Args: + x (Tensor): 2-D image input tensor with shape (n, c, h, w). + """ + + def __init__(self, channels: int): + super().__init__() + self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + + def forward(self, x: Tensor) -> Tensor: + return self.conv(F.interpolate(x, scale_factor=2, mode="nearest")) + + +class Downsample2D(nn.Module): + """2-Dimensional downsampling layer with zero padding and 2D convolution, + used for image encoders. 
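# --- Illustrative sketch (not part of the patch): a conditioned residual block built
# with adm_res_block, using arbitrary sizes. With the default scale_shift_conditional=True,
# adm_cond_proj maps the 256-dim embedding to 2 * 128 channels, which forward() chunks
# into the (1 + scale) and shift terms applied after the group norm.
import torch
from torchmultimodal.diffusion_labs.models.vae.res_block import adm_res_block

block = adm_res_block(in_channels=64, out_channels=128, dim_cond=256)
x = torch.randn(2, 64, 16, 16)
t_emb = torch.randn(2, 256)  # e.g. a timestep embedding
assert block(x, conditional_embedding=t_emb).shape == (2, 128, 16, 16)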
+
+    Code ref:
+    https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L134
+
+    Attributes:
+        channels (int): Number of channels in the input.
+        asymmetric_padding (bool): Whether to use asymmetric padding.
+            Defaults to True.
+
+    Args:
+        x (Tensor): 2-D image input tensor with shape (n, c, h, w).
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        asymmetric_padding: bool = True,
+    ):
+        super().__init__()
+        if asymmetric_padding:
+            padding = (0, 1, 0, 1)
+        else:
+            padding = 1
+        self.op = nn.Sequential(
+            nn.ZeroPad2d(padding),
+            nn.Conv2d(channels, channels, kernel_size=3, stride=2),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.op(x)
diff --git a/torchmultimodal/diffusion_labs/models/vae/vae.py b/torchmultimodal/diffusion_labs/models/vae/vae.py
new file mode 100644
index 00000000..878a675e
--- /dev/null
+++ b/torchmultimodal/diffusion_labs/models/vae/vae.py
@@ -0,0 +1,142 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+from typing import NamedTuple, Sequence
+
+import torch
+from torch import nn, Tensor
+from torch.distributions import Distribution, Normal
+from torchmultimodal.diffusion_labs.models.vae.encoder_decoder import (
+    ResNetDecoder,
+    ResNetEncoder,
+)
+
+
+class VAEOutput(NamedTuple):
+    posterior: Distribution
+    decoder_output: Tensor
+
+
+class VariationalAutoencoder(nn.Module):
+    """Variational Autoencoder (https://arxiv.org/abs/1906.02691) is a special type of autoencoder
+    where the encoder outputs the parameters of the posterior latent distribution instead of
+    outputting fixed vectors in the latent space. The decoder consumes a sample from the latent
+    distribution to reconstruct the inputs.
+
+    Follows the architecture used in "High-Resolution Image Synthesis with Latent
+    Diffusion Models" (https://arxiv.org/abs/2112.10752)
+
+    Code ref:
+    https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/autoencoder.py#L285
+
+    Attributes:
+        encoder (nn.Module): instance of encoder module.
+        decoder (nn.Module): instance of decoder module.
+
+    Args:
+        x (Tensor): input Tensor of shape [b, c, h, w]
+        sample_posterior (bool): if True, sample from the posterior instead of using
+            the distribution mode. Defaults to True.
+    """
+
+    def __init__(self, encoder: nn.Module, decoder: nn.Module):
+        super().__init__()
+        torch._C._log_api_usage_once(f"torchmultimodal.{self.__class__.__name__}")
+
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def encode(self, x: Tensor) -> Distribution:
+        h = self.encoder(x)
+        # output of encoder is the mean and log variance of a normal distribution
+        mean, log_variance = torch.chunk(h, 2, dim=1)
+        # clamp log variance to [-30.0, 20.0] for numerical stability
+        log_variance = torch.clamp(log_variance, -30.0, 20.0)
+        stddev = torch.exp(log_variance / 2.0)
+        posterior = Normal(mean, stddev)
+        return posterior
+
+    def decode(self, z: Tensor) -> Tensor:
+        return self.decoder(z)
+
+    def forward(self, x: Tensor, sample_posterior: bool = True) -> VAEOutput:
+        posterior = self.encode(x)
+        if sample_posterior:
+            z = posterior.rsample()
+        else:
+            z = posterior.mode
+        decoder_out = self.decode(z)
+        return VAEOutput(posterior=posterior, decoder_output=decoder_out)
+
+
+def variational_autoencoder(
+    *,
+    embedding_channels: int,
+    in_channels: int,
+    out_channels: int,
+    z_channels: int,
+    channels: int,
+    num_res_blocks: int,
+    channel_multipliers: Sequence[int] = (1, 2, 4, 8),
+    dropout: float = 0.0,
+    norm_groups: int = 32,
+    norm_eps: float = 1e-6,
+    output_alpha_channel: bool = False,
+) -> VariationalAutoencoder:
+    encoder = nn.Sequential(
+        # pyre-ignore
+        OrderedDict(
+            [
+                (
+                    "resnet_encoder",
+                    ResNetEncoder(
+                        in_channels=in_channels,
+                        z_channels=z_channels,
+                        channels=channels,
+                        num_res_blocks=num_res_blocks,
+                        channel_multipliers=channel_multipliers,
+                        dropout=dropout,
+                        norm_groups=norm_groups,
+                        norm_eps=norm_eps,
+                        double_z=True,
+                    ),
+                ),
+                (
+                    "quant_conv",
+                    nn.Conv2d(2 * z_channels, 2 * embedding_channels, kernel_size=1),
+                ),
+            ]
+        )
+    )
+
+    decoder = nn.Sequential(
+        # pyre-ignore
+        OrderedDict(
+            [
+                (
+                    "post_quant_conv",
+                    nn.Conv2d(embedding_channels, z_channels, kernel_size=1),
+                ),
+                (
+                    "resnet_decoder",
+                    ResNetDecoder(
+                        out_channels=out_channels,
+                        z_channels=z_channels,
+                        channels=channels,
+                        num_res_blocks=num_res_blocks,
+                        channel_multipliers=channel_multipliers,
+                        dropout=dropout,
+                        norm_groups=norm_groups,
+                        norm_eps=norm_eps,
+                        output_alpha_channel=output_alpha_channel,
+                    ),
+                ),
+            ]
+        )
+    )
+
+    return VariationalAutoencoder(encoder=encoder, decoder=decoder)
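# --- Illustrative sketch (not part of the patch): end-to-end shapes for the builder
# above. The values are arbitrary; with channel_multipliers of length 3 the encoder
# halves the resolution twice (one stride-2 Downsample2D per level except the last),
# so a 32x32 image maps to an 8x8 latent with `embedding_channels` channels, and the
# decoder upsamples it back.
import torch
from torchmultimodal.diffusion_labs.models.vae.vae import variational_autoencoder

vae = variational_autoencoder(
    embedding_channels=4,
    in_channels=3,
    out_channels=3,
    z_channels=4,
    channels=128,
    num_res_blocks=2,
    channel_multipliers=(1, 2, 4),
)
x = torch.randn(1, 3, 32, 32)
out = vae(x)                      # VAEOutput(posterior, decoder_output)
z = out.posterior.rsample()
assert z.shape == (1, 4, 8, 8)
assert out.decoder_output.shape == x.shape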