From e815ebb4a259b12a094dbae45f4302e22dd2289b Mon Sep 17 00:00:00 2001 From: Hanzi Mao Date: Tue, 20 Jul 2021 12:26:49 -0700 Subject: [PATCH] Optional shared encoder for the two tower model. Summary: Provide shared encoder option for the two tower model. Reviewed By: sinonwang Differential Revision: D29557881 fbshipit-source-id: 501b1d2304326f6d3f05b1c37ca494d0727c8547 --- .../models/two_tower_classification_model.py | 52 ++++++++++++++----- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/pytext/models/two_tower_classification_model.py b/pytext/models/two_tower_classification_model.py index 225b3f070..832c896ce 100644 --- a/pytext/models/two_tower_classification_model.py +++ b/pytext/models/two_tower_classification_model.py @@ -46,6 +46,7 @@ class InputConfig(ConfigBase): output_layer: ClassificationOutputLayer.Config = ( ClassificationOutputLayer.Config() ) + use_shared_encoder: bool = False def trace(self, inputs): return torch.jit.trace(self, inputs) @@ -113,16 +114,25 @@ def forward( left_encoder_inputs: Tuple[torch.Tensor, ...], *args ) -> List[torch.Tensor]: - if self.right_encoder.output_encoded_layers: - # if encoded layers are returned, discard them - right_representation = self.right_encoder(right_encoder_inputs)[1] + if self.use_shared_encoder: + if self.right_encoder.output_encoded_layers: + # if encoded layers are returned, discard them + right_representation = self.right_encoder(right_encoder_inputs)[1] + left_representation = self.right_encoder(left_encoder_inputs)[1] + else: + right_representation = self.right_encoder(right_encoder_inputs)[0] + left_representation = self.right_encoder(left_encoder_inputs)[0] else: - right_representation = self.right_encoder(right_encoder_inputs)[0] - if self.left_encoder.output_encoded_layers: - # if encoded layers are returned, discard them - left_representation = self.left_encoder(left_encoder_inputs)[1] - else: - left_representation = self.left_encoder(left_encoder_inputs)[0] + if self.right_encoder.output_encoded_layers: + # if encoded layers are returned, discard them + right_representation = self.right_encoder(right_encoder_inputs)[1] + else: + right_representation = self.right_encoder(right_encoder_inputs)[0] + if self.left_encoder.output_encoded_layers: + # if encoded layers are returned, discard them + left_representation = self.left_encoder(left_encoder_inputs)[1] + else: + left_representation = self.left_encoder(left_encoder_inputs)[0] return self.decoder(right_representation, left_representation, *args) def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None): @@ -173,16 +183,32 @@ def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]): output_layer_cls = MulticlassOutputLayer output_layer = output_layer_cls(list(labels), loss) - return cls(right_encoder, left_encoder, decoder, output_layer) + return cls( + right_encoder, + left_encoder, + decoder, + output_layer, + config.use_shared_encoder, + ) def __init__( - self, right_encoder, left_encoder, decoder, output_layer, stage=Stage.TRAIN + self, + right_encoder, + left_encoder, + decoder, + output_layer, + use_shared_encoder, + stage=Stage.TRAIN, ) -> None: super().__init__(stage=stage) self.right_encoder = right_encoder - self.left_encoder = left_encoder + self.use_shared_encoder = use_shared_encoder self.decoder = decoder - self.module_list = [right_encoder, left_encoder, decoder] + if self.use_shared_encoder: + self.module_list = [right_encoder, decoder] + else: + self.left_encoder = left_encoder + self.module_list = [right_encoder, left_encoder, decoder] self.output_layer = output_layer self.stage = stage log_class_usage(__class__)