Fixes for CoCa cascaded attention poolers (#518)
Summary:
A couple of fixes to CoCa's attention pooling, as pointed out in #517. Specifically, in the cascaded attention pooling case we need to change the contrastive pooler's input dim to match the captioning pooler's output dim. We should also set `n_queries=1` for the contrastive pooler so that the pooled embeddings can be fed directly into the contrastive loss (after appropriate normalization).
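
For illustration only, here is a minimal toy sketch of the shape constraint behind both changes. `ToyAttentionPooler` and the dims are made up for this example; it is not the library's `AttentionPooler`, just a stand-in with the same (batch, n_queries, output_dim) pooling behavior:

```
import torch
from torch import nn


class ToyAttentionPooler(nn.Module):
    """Illustrative stand-in: learned queries cross-attend to the input
    sequence and return a (batch, n_queries, output_dim) pooled tensor."""

    def __init__(self, input_dim: int, output_dim: int, n_head: int, n_queries: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, output_dim))
        self.proj = nn.Linear(input_dim, output_dim)
        self.attn = nn.MultiheadAttention(output_dim, n_head, batch_first=True)

    def forward(self, x):
        # x: (batch, seq_len, input_dim) -> project keys/values to output_dim
        kv = self.proj(x)
        q = self.query.unsqueeze(0).expand(x.size(0), -1, -1)
        pooled, _ = self.attn(q, kv, kv)
        return pooled  # (batch, n_queries, output_dim)


input_dim, output_dim = 6, 8
# First (captioning) pooler consumes the vision encoder output.
captioning_pooler = ToyAttentionPooler(input_dim, output_dim, n_head=2, n_queries=256)
# Second (contrastive) pooler consumes the captioning pooler's output, so its
# input dim must equal output_dim, and a single query lets the pooled result
# feed the contrastive loss directly after normalization.
contrastive_pooler = ToyAttentionPooler(output_dim, output_dim, n_head=2, n_queries=1)

image_feats = torch.randn(2, 36, input_dim)           # (batch, seq_len, input_dim)
captioning_out = captioning_pooler(image_feats)        # (2, 256, 8)
contrastive_out = contrastive_pooler(captioning_out)   # (2, 1, 8)
print(captioning_out.shape, contrastive_out.shape)
```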

Pull Request resolved: #518

Test Plan:
```
import torch

from torchmultimodal.models.coca.coca_model import coca_vit_l_14
model = coca_vit_l_14()
bs, c, h, w, seq_len, vocab_size = 2, 3, 224, 224, 77, 49408
images = torch.randn(bs, c, h, w)
texts = torch.randint(0, vocab_size, (bs, seq_len))
out = model(images, texts)
print(out.image_pooled_output.shape, out.multimodal_embeddings.shape)
...
torch.Size([2, 1, 768]) torch.Size([2, 76, 49408])
```

Add new unit test:

```
python -m pytest -v tests/models/coca/test_coca_model.py
...
===== 4 passed in 3.18s ======
```

Reviewed By: pbontrager

Differential Revision: D52523771

Pulled By: ebsmothers

fbshipit-source-id: 7c0197605e478ae6e3204f1ec0ab2e6adbf2377e
ebsmothers authored and facebook-github-bot committed Jan 4, 2024
1 parent fc92cea commit 63c629a
Showing 2 changed files with 61 additions and 33 deletions.
tests/models/coca/test_coca_model.py: 52 additions & 29 deletions
@@ -44,39 +44,42 @@ def image_size(self):
         return 12
 
     @pytest.fixture
-    def coca_model(
+    def get_coca_model(
         self,
         vocab_size,
         num_text_positions,
         attention_pooler_output_dim,
         text_output_dim,
         image_size,
     ):
-        coca_model = coca_vit(
-            vision_patch_size=4,
-            vision_dim_feedforward=24,
-            vision_n_layer=2,
-            vision_n_head=2,
-            vocab_size=vocab_size,
-            num_text_positions=num_text_positions,
-            text_hidden_dim=8,
-            text_n_layer=2,
-            text_n_head=2,
-            text_dim_feedforward=32,
-            text_output_dim=text_output_dim,
-            fusion_n_layer=2,
-            fusion_n_head=2,
-            fusion_dim_feedforward=32,
-            multimodal_output_projection_dim=vocab_size,
-            pooler_input_embed_dim=6,
-            pooler_output_embed_dim=attention_pooler_output_dim,
-            image_size=image_size,
-            pooler_n_head=2,
-            cascaded_pooler=False,
-        )
-        init_weights_with_constant(coca_model)
-        coca_model.eval()
-        return coca_model
+        def create_coca_model(cascaded_pooler: bool = False):
+            coca_model = coca_vit(
+                vision_patch_size=4,
+                vision_dim_feedforward=24,
+                vision_n_layer=2,
+                vision_n_head=2,
+                vocab_size=vocab_size,
+                num_text_positions=num_text_positions,
+                text_hidden_dim=8,
+                text_n_layer=2,
+                text_n_head=2,
+                text_dim_feedforward=32,
+                text_output_dim=text_output_dim,
+                fusion_n_layer=2,
+                fusion_n_head=2,
+                fusion_dim_feedforward=32,
+                multimodal_output_projection_dim=vocab_size,
+                pooler_input_embed_dim=6,
+                pooler_output_embed_dim=attention_pooler_output_dim,
+                image_size=image_size,
+                pooler_n_head=2,
+                cascaded_pooler=cascaded_pooler,
+            )
+            init_weights_with_constant(coca_model)
+            coca_model.eval()
+            return coca_model
+
+        return create_coca_model
 
     @pytest.fixture
     def text_inputs(self):
@@ -111,17 +114,37 @@ def expected(
         )
 
     @pytest.fixture
-    def coca_for_pretraining(self, coca_model):
+    def coca_for_pretraining(self, get_coca_model):
+        coca_model = get_coca_model()
         coca_for_pretraining = CoCaForPretraining(coca_model)
         init_weights_with_constant(coca_for_pretraining)
         coca_for_pretraining.eval()
         return coca_for_pretraining
 
-    def test_coca_model(self, text_inputs, image_inputs, coca_model, expected):
+    def test_coca_model(self, text_inputs, image_inputs, get_coca_model, expected):
+        coca_model = get_coca_model()
         actual = coca_model(image_inputs, text_inputs)
         assert_expected(actual, expected, rtol=0, atol=1e-4)
 
-    def test_scripting(self, text_inputs, image_inputs, coca_model):
+    def test_coca_model_cascaded_pooler(
+        self,
+        text_inputs,
+        image_inputs,
+        get_coca_model,
+        batch_size,
+        attention_pooler_output_dim,
+    ):
+        coca_model_cascaded_pooler = get_coca_model(cascaded_pooler=True)
+        actual = coca_model_cascaded_pooler(image_inputs, text_inputs)
+        assert_expected(
+            actual.image_pooled_output.shape,
+            (batch_size, 1, attention_pooler_output_dim),
+            rtol=0,
+            atol=1e-4,
+        )
+
+    def test_scripting(self, text_inputs, image_inputs, get_coca_model):
+        coca_model = get_coca_model()
         scripted_model = torch.jit.script(coca_model)
         assert_expected(
             scripted_model(image_inputs, text_inputs),
torchmultimodal/models/coca/coca_model.py: 9 additions & 4 deletions
@@ -245,7 +245,12 @@ def coca_vit(
             multimodal embeddings. Default: None
         cascaded_pooler (bool): Whether to cascade (stack) contrastive and captioning
             attention poolers or parallelize them. Default: True
-        pooler_n_queries (int): Number of queries in attention pooler. Default: 256
+        pooler_n_queries (int): Number of queries in captioning attention pooler.
+            Contrastive attention pooler always has one query. For parallel pooling,
+            the attention pooler will have a single pooler using n_queries+1 queries
+            with the first position for contrastive embeddings. For cascaded pooling,
+            the first pooler is the captioning pooler with pooler_n_queries queries
+            and the second is the contrastive pooler with one query. Default: 256
         pooler_layer_norm_eps (float): LN epsilon in attention pooler. Default: 1e-5
     """
     attention_pooler: nn.Module
@@ -258,10 +263,10 @@
             layer_norm_eps=pooler_layer_norm_eps,
         )
         contrastive_pooler = AttentionPooler(
-            input_embed_dim=pooler_input_embed_dim,
+            input_embed_dim=pooler_output_embed_dim,
             output_embed_dim=pooler_output_embed_dim,
             n_head=pooler_n_head,
-            n_queries=pooler_n_queries,
+            n_queries=1,
             layer_norm_eps=pooler_layer_norm_eps,
         )
         attention_pooler = CascadedAttentionPooler(
@@ -272,7 +277,7 @@
             input_embed_dim=pooler_input_embed_dim,
             output_embed_dim=pooler_output_embed_dim,
             n_head=pooler_n_head,
-            n_queries=pooler_n_queries,
+            n_queries=pooler_n_queries + 1,
             layer_norm_eps=pooler_layer_norm_eps,
         )
 
