Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Optional shared encoder for the two tower model.
Browse files Browse the repository at this point in the history
Summary: Provide shared encoder option for the two tower model.

Reviewed By: sinonwang

Differential Revision: D29557881

fbshipit-source-id: 501b1d2304326f6d3f05b1c37ca494d0727c8547
  • Loading branch information
HannaMao authored and facebook-github-bot committed Jul 20, 2021
1 parent 2005e82 commit e815ebb
Showing 1 changed file with 39 additions and 13 deletions.
52 changes: 39 additions & 13 deletions pytext/models/two_tower_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class InputConfig(ConfigBase):
# Config for the classification head placed on top of the two towers.
output_layer: ClassificationOutputLayer.Config = (
ClassificationOutputLayer.Config()
)
# When True, the right tower's encoder is reused for the left tower's
# inputs as well (see forward); defaults to two independent encoders.
use_shared_encoder: bool = False

def trace(self, inputs):
    """Return a TorchScript trace of this model on the given example inputs."""
    traced_module = torch.jit.trace(self, inputs)
    return traced_module
def forward(
    self,
    right_encoder_inputs: Tuple[torch.Tensor, ...],
    left_encoder_inputs: Tuple[torch.Tensor, ...],
    *args
) -> List[torch.Tensor]:
    """Encode both towers and decode the pair of representations.

    When ``self.use_shared_encoder`` is True, the right encoder is applied
    to BOTH input tuples (the left encoder is unused); otherwise each tower
    runs through its own encoder.  Encoders configured with
    ``output_encoded_layers`` return per-layer outputs at index 0 and the
    pooled representation at index 1, so the pooled value is selected and
    the encoded layers are discarded.

    NOTE(review): the ``self``/``right_encoder_inputs`` parameters were cut
    off in the visible diff hunk and are reconstructed from body usage —
    confirm against the full file.
    """
    if self.use_shared_encoder:
        # Single (right) encoder serves both towers.
        if self.right_encoder.output_encoded_layers:
            # if encoded layers are returned, discard them
            right_representation = self.right_encoder(right_encoder_inputs)[1]
            left_representation = self.right_encoder(left_encoder_inputs)[1]
        else:
            right_representation = self.right_encoder(right_encoder_inputs)[0]
            left_representation = self.right_encoder(left_encoder_inputs)[0]
    else:
        if self.right_encoder.output_encoded_layers:
            # if encoded layers are returned, discard them
            right_representation = self.right_encoder(right_encoder_inputs)[1]
        else:
            right_representation = self.right_encoder(right_encoder_inputs)[0]
        if self.left_encoder.output_encoded_layers:
            # if encoded layers are returned, discard them
            left_representation = self.left_encoder(left_encoder_inputs)[1]
        else:
            left_representation = self.left_encoder(left_encoder_inputs)[0]
    return self.decoder(right_representation, left_representation, *args)

def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
Expand Down Expand Up @@ -173,16 +183,32 @@ def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
output_layer_cls = MulticlassOutputLayer

output_layer = output_layer_cls(list(labels), loss)
return cls(right_encoder, left_encoder, decoder, output_layer)
return cls(
right_encoder,
left_encoder,
decoder,
output_layer,
config.use_shared_encoder,
)

def __init__(
    self,
    right_encoder,
    left_encoder,
    decoder,
    output_layer,
    use_shared_encoder,
    stage=Stage.TRAIN,
) -> None:
    """Build the two-tower model.

    Args:
        right_encoder: encoder for the right tower; when
            ``use_shared_encoder`` is True it also encodes the left inputs.
        left_encoder: encoder for the left tower; unused (and not registered
            in ``module_list``) when the encoder is shared.
        decoder: combines the two tower representations (see forward).
        output_layer: classification output layer.
        use_shared_encoder: share the right encoder across both towers.
        stage: training stage, defaults to ``Stage.TRAIN``.
    """
    super().__init__(stage=stage)
    self.right_encoder = right_encoder
    self.use_shared_encoder = use_shared_encoder
    self.decoder = decoder
    if self.use_shared_encoder:
        # NOTE: self.left_encoder is deliberately not set here, so the
        # shared-encoder model carries no unused tower parameters.
        self.module_list = [right_encoder, decoder]
    else:
        self.left_encoder = left_encoder
        self.module_list = [right_encoder, left_encoder, decoder]
    self.output_layer = output_layer
    self.stage = stage
    log_class_usage(__class__)

0 comments on commit e815ebb

Please sign in to comment.